1.writer.add_graph(model, model_input)
1  | from torch.utils.tensorboard import SummaryWriter  | 
使用TensorBoard对训练可视化,在增加模型的计算图的时候,报了以下错误:
1  | writer.add_graph(model, [dummy_input, dummy_input,dummy_input], verbose=True)  | 
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
输入的数据类型与网络参数的类型不符。
查询网上都说:

但是全都是bullshit!!!!!
正确解决方案:
1  | dummy_input = torch.rand(2, 3, 378,378)  | 
注意:pytorch模型不会记录其输入输出的大小,更不会记录每层输出的尺寸。
所以,tensorbaord需要一个假的数据 data (dummy_input)来探测网络各层输出大小,并指示输入尺寸。


