首页 新闻 搜索 专区 学院

用pytorch的scatter函数把标签转换为one-hot编码为啥会出错啊……

0
悬赏园豆:50 [待解决问题]

import torch
import numpy as np

x = np.array([1, 2, 3])
x_tensor = torch.from_numpy(x)
y = x_tensor.reshape(-1, 1)
y_one_hot = torch.zeros(3, 5)
y_one_hot.scatter_(1, y, 1)
print(y_one_hot)

运行结果在这里:

大佬们,为什么会出现这种问题啊,有没有解决方法!!!

玖玖牛的主页 玖玖牛 | 初学一级 | 园豆:152
提问于:2020-06-12 03:03
< >
分享
清除回答草稿
   您需要登录以后才能回答,未注册用户请先注册