import torch
from torch.autograd import Variable

def distconvert(self, dist, s, m):
    a = torch.ones(dist.size())
    b = torch.ones(dist.size())
    c = a + b  # every element of c is 2
    final_dist = torch.pow(Variable(c), dist)
    return final_dist
When execution reaches the torch.pow call, I get the following error:
(float base)didn't match because some of the arguments have invalid types:(torch.cuda.FloatTensor)
Checking type(dist) shows that dist is a Variable as well.
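The error message lists torch.cuda.FloatTensor, which suggests dist lives on the GPU, while torch.ones(dist.size()) creates a CPU tensor by default. torch.pow therefore seems to receive a CPU base and a CUDA exponent. The sketch below shows one way around that. It assumes PyTorch 0.4 or later, where Variable is merged into Tensor and no wrapping is needed, and it builds the base directly on dist's device with torch.full_like. It is written as a standalone function rather than a method, and s and m are kept only to mirror the original signature.

```python
import torch


def distconvert(dist: torch.Tensor, s=None, m=None) -> torch.Tensor:
    """Return 2 ** dist elementwise, on the same device and dtype as dist.

    Standalone sketch: s and m are unused and only mirror the original signature.
    """
    # Create the base with dist's shape, dtype and device, so a CUDA input
    # gets a CUDA base instead of the default CPU tensor from torch.ones.
    base = torch.full_like(dist, 2.0)
    return torch.pow(base, dist)


if __name__ == "__main__":
    # Use the GPU when available; the same code also runs on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dist = torch.tensor([0.0, 1.0, 3.0], device=device, requires_grad=True)
    out = distconvert(dist)
    print(out.tolist())  # [1.0, 2.0, 8.0]
```

Because the base is a constant 2, torch.pow(2.0, dist) with a Python scalar base should give the same result without allocating a separate tensor. Gradients still flow back to dist through the exponent in either form.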