如题,直接上代码:
import torch
import torch.utils.data as Data
from torch.nn.functional import softmax
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
# Mini-batch regression demo: fit y = x^2 + noise with a small network and
# redraw the current fit in an interactive matplotlib window while training.
BATCH_SIZE = 100

# 1000 evenly spaced inputs in [-1, 1] as a column of shape (1000, 1);
# targets are x^2 plus uniform noise drawn from [0, 0.2).
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # every batch arrives in random x order
)

# NOTE(review): with a single hidden unit the network can only represent a
# hinge-shaped (piecewise-linear, one kink) curve, so it cannot follow the
# parabola closely. Widening the hidden layer would improve the fit.
net = torch.nn.Sequential(
    torch.nn.Linear(1, 1),
    torch.nn.ReLU(),
    torch.nn.Linear(1, 1),
)

plt.ion()  # interactive mode so the figure can be refreshed inside the loop
plt.show()

optimizer = torch.optim.SGD(net.parameters(), lr=0.05)
loss_func = torch.nn.MSELoss()

for epoch in range(3):
    for step, (batch_x, batch_y) in enumerate(loader):
        # The batch is shuffled, so compute once the permutation that puts
        # its x values in ascending order. Drawing the prediction as a line
        # through unsorted x makes matplotlib connect points in random
        # order, which shows up as a thick zig-zag scribble.
        order = torch.argsort(batch_x[:, 0])
        sorted_x = batch_x[order].detach().numpy()

        for t in range(1000):
            prediction = net(batch_x)
            loss = loss_func(prediction, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if t % 5 == 0:
                # Redraw from scratch: raw samples plus the current fit.
                plt.cla()
                plt.scatter(batch_x.detach().numpy(), batch_y.detach().numpy())
                plt.plot(
                    sorted_x,
                    prediction[order].detach().numpy(),
                    'r-',
                    lw=5,
                )
                plt.text(
                    0.5,
                    0,
                    'loss=%.4f' % loss.item(),
                    fontdict={'size': 20, 'color': 'red'},
                )
                plt.pause(0.1)

plt.ioff()  # leave interactive mode so the final show() blocks
plt.show()
在 IDE 中运行时,会出现一堆诡异的图像,求解。
顶
– FogFoge 4年前再顶
– FogFoge 4年前继续顶
– FogFoge 4年前求回复
– FogFoge 4年前