import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        # Stem: 3 -> 96 channels at full resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Encoder: each down block doubles the channels and halves H and W
        # with its final stride-2 convolution.
        self.down1 = nn.Sequential(
            nn.Conv2d(96, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(192, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(384, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Bottleneck: one more stride-2 step down, then a transposed
        # convolution straight back up to the down3 resolution.
        self.down_and_up = nn.Sequential(
            nn.Conv2d(768, 1536, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm2d(1536),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1536, 1536, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(1536),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(1536, 768, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Decoder: each up block receives the channel-concatenated skip input
        # (hence the doubled in_channels) and upsamples with its final layer.
        self.up1 = nn.Sequential(
            nn.Conv2d(1536, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 768, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(768),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(768, 384, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.up2 = nn.Sequential(
            nn.Conv2d(768, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(384),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(384, 192, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.up3 = nn.Sequential(
            nn.Conv2d(384, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(192),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(192, 96, kernel_size=2, padding=0, stride=2),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(192, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(96, 96, kernel_size=3, padding=1, stride=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Head: 1x1 convolution down to RGB, squashed to [-1, 1] by Tanh.
        self.output = nn.Sequential(
            nn.Conv2d(96, 3, kernel_size=1, stride=1),
            nn.Tanh()
        )
    def forward(self, x):
        # U-Net forward pass: keep each encoder output for a skip connection.
        x1 = self.conv1(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.down_and_up(x4)
        x = torch.cat([x, x4], dim=1)
        x = self.up1(x)
        x = torch.cat([x, x3], dim=1)
        x = self.up2(x)
        x = torch.cat([x, x2], dim=1)
        x = self.up3(x)
        x = torch.cat([x, x1], dim=1)
        x = self.conv2(x)
        x = self.output(x)
        return x
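# A quick shape sanity check of the generator above: with four stride-2
# downsamplings, H and W should be multiples of 16 (64 here is a dummy size),
# and the skip concatenations then line up so the output matches the input.
with torch.no_grad():
    assert Generator()(torch.randn(1, 3, 64, 64)).shape == (1, 3, 64, 64)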
class Distinguish(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # Condition image and (real or fake) output are stacked along
            # the channel axis, hence 6 input channels.
            nn.Conv2d(6, 64, kernel_size=11, padding=5, stride=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.2),
            nn.Conv2d(64, 128, kernel_size=5, padding=2),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 128, kernel_size=5, padding=2),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=5, padding=2, stride=2),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=5, padding=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=5, padding=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 512, kernel_size=5, padding=2, stride=2),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Note: kernel_size=5 with padding=1 shrinks the map by 2 pixels.
            nn.Conv2d(512, 512, kernel_size=5, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 1-channel patch score map, flattened; no sigmoid, because the
            # training loop below uses WGAN losses on raw critic scores.
            nn.Conv2d(512, 1, kernel_size=3, padding=1),
            nn.Flatten(),
        )
    def forward(self, I, O):
        # Score the (condition, output) pair as a whole.
        x = torch.cat([I, O], dim=1)
        x = self.model(x)
        return x
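# The critic is PatchGAN-flavoured: three stride-2 stages reduce a 64x64
# pair as 64 -> 32 -> 16 -> 8, the padding=1 layer trims that to 6, and the
# flattened 1-channel map yields 6*6 = 36 patch scores per sample.
with torch.no_grad():
    assert Distinguish()(torch.randn(1, 3, 64, 64),
                         torch.randn(1, 3, 64, 64)).shape == (1, 36)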
# Keep the critic on the same device as the generator so real, fake, and
# condition images do not shuttle between CPU and GPU during training.
D = Distinguish().cuda()
try:
    D.load_state_dict(torch.load("D.pth"))
except Exception:
    pass  # no usable checkpoint yet; train from scratch
G = Generator().cuda()
try:
    G.load_state_dict(torch.load("G.pth"))
except Exception:
    pass  # no usable checkpoint yet; train from scratch
totensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
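# Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps pixels from [0, 1] to
# [-1, 1], matching the Tanh at the generator's head; a tiny check on a
# solid-colour dummy image:
_t = totensor(Image.new("RGB", (4, 4), (255, 0, 128)))
assert _t.min().item() == -1.0 and _t.max().item() == 1.0
del _t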
class Pairdata(Dataset):
    def __init__(self):
        ...  # dataset setup elided in the original; must populate self.pair_img_path
    def __getitem__(self, idx):
        ...  # elided in the original; returns an (input_img, output_img) pair
    def __len__(self):
        return len(self.pair_img_path)
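# A minimal sketch of how Pairdata's elided pieces might be filled in,
# assuming paired images share filenames across two folders; the class name,
# folder names, and arguments here are placeholders, not from the original:
class PairFolderDataset(Dataset):
    def __init__(self, input_dir="input", target_dir="target"):
        self.input_dir, self.target_dir = input_dir, target_dir
        self.pair_img_path = sorted(os.listdir(input_dir))
    def __getitem__(self, idx):
        name = self.pair_img_path[idx]
        input_img = totensor(Image.open(os.path.join(self.input_dir, name)).convert("RGB"))
        output_img = totensor(Image.open(os.path.join(self.target_dir, name)).convert("RGB"))
        # unsqueeze to a batch of one, which is what the loop below consumes
        return input_img.unsqueeze(0), output_img.unsqueeze(0)
    def __len__(self):
        return len(self.pair_img_path)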
pairdata_loader = Pairdata()
# RMSprop plus the weight clipping below follows the WGAN recipe, though
# this learning rate is far higher than the 5e-5 the WGAN paper suggests.
D_optim = torch.optim.RMSprop(D.parameters(), lr=0.02)
G_optim = torch.optim.RMSprop(G.parameters(), lr=0.02)
l1loss = nn.L1Loss()
BCEloss = nn.BCEWithLogitsLoss()  # defined but unused by the WGAN-style loop below
for i in range(5):
    # Gradients accumulate over four pairs, then both networks step once.
    D_optim.zero_grad()
    G_optim.zero_grad()
    for j in range(4):
        input_img, output_img = pairdata_loader[j]
        input_img, output_img = input_img.cuda(), output_img.cuda()
        # Critic loss (WGAN): push real scores up, fake scores down.
        # no_grad keeps the generator out of the critic's graph.
        with torch.no_grad():
            fake_img = G(input_img)
        pred_real = D(input_img, output_img)
        loss_real = -torch.mean(pred_real)
        pred_fake = D(input_img, fake_img)
        loss_fake = torch.mean(pred_fake)
        current_D_loss = loss_real + loss_fake
        current_D_loss.backward()
        print('D:', current_D_loss.item())
        # Generator loss: adversarial term plus a weighted L1 term. Freeze
        # the critic so this backward pass does not pollute the gradients
        # it accumulated above.
        for p in D.parameters():
            p.requires_grad_(False)
        fake_img = G(input_img)
        pred_fake = D(input_img, fake_img)
        advloss = -torch.mean(pred_fake)
        l1_loss = l1loss(fake_img, output_img)
        current_G_loss = advloss + l1_loss * 2
        current_G_loss.backward()
        for p in D.parameters():
            p.requires_grad_(True)
        print('G:', current_G_loss.item())
    D_optim.step()
    G_optim.step()
    # WGAN weight clipping keeps the critic approximately Lipschitz.
    clip_value = 0.02
    for param in D.parameters():
        param.data.clamp_(-clip_value, clip_value)
    print('weights updated and clipped')
torch.save(D.state_dict(), "D.pth")
torch.save(G.state_dict(), "G.pth")
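# Hand-indexing the dataset works for a demo, but the idiomatic way to batch
# and shuffle pairs is a DataLoader; a sketch (this assumes __getitem__
# returns unbatched (C, H, W) tensors rather than batches of one):
def make_loader(dataset, batch_size=4):
    # shuffled mini-batches would replace the hand-built batches of one above
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)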
unloader = transforms.ToPILImage()

def tensor_to_PIL(tensor):
    image = tensor.cpu().clone()
    image = image.squeeze(0)       # drop the batch dimension
    image = (image + 1) / 2.0      # undo the [-1, 1] normalization for display
    image = unloader(image)
    return image

# Preview the generator's output on the first pair.
with torch.no_grad():
    tensor_to_PIL(G(pairdata_loader[0][0].cuda())).show()
# Keep the script alive so the external image viewer can be inspected.
input('Press Enter to exit')