Pytorch错误解决

1.writer.add_graph(model, model_input)

1
2
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./run/exp')

使用TensorBoard对训练可视化,在增加模型的计算图的时候,报了以下错误:

1
writer.add_graph(model, [dummy_input, dummy_input,dummy_input], verbose=True)

RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

输入的数据类型与网络参数的类型不符

查询网上都说:

但是全都是bullshit!!!!!

正确解决方案:

1
2
dummy_input = torch.rand(2, 3, 378,378)
writer.add_graph(model, [dummy_input.cuda(), dummy_input.cuda(), dummy_input.cuda()], verbose=True)

注意:pytorch模型不会记录其输入输出的大小,更不会记录每层输出的尺寸。

所以,tensorbaord需要一个假的数据 datadummy_input)来探测网络各层输出大小,并指示输入尺寸。