首页 新闻 会员 周边

Pytorch神经网络出现问题

0
悬赏园豆:10 [待解决问题]

Pytorch神经网络出现问题

rt,直接上代码

import torch
import torch.utils.data as Data
from torch.nn.functional import softmax
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

BATCH_SIZE = 100

x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

net = torch.nn.Sequential(
    torch.nn.Linear(1, 1),
    torch.nn.ReLU(),
    torch.nn.Linear(1, 1),
)

plt.ion()
plt.show()

optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        for t in range(1000):
            prediction = net(batch_x)

            loss = loss_func(prediction, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if t % 5 == 0:
                plt.cla()
                plt.scatter(batch_x.data.numpy(), batch_y.data.numpy())
                plt.plot(batch_x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
                plt.text(0.5, 0, 'loss=%.4f' % loss.data, fontdict={'size': 20,'color': 'red'})
                plt.pause(0.1)

plt.ioff()
plt.show()

如果在IDE运行时会出现一堆诡异的图像,求解

FogFoge的主页 FogFoge | 初学一级 | 园豆:192
提问于:2020-06-21 13:30

FogFoge 3年前

再顶

FogFoge 3年前

继续顶

FogFoge 3年前

求回复

FogFoge 3年前
< >
分享
清除回答草稿
   您需要登录以后才能回答,未注册用户请先注册