Basic imports
import os
from PIL import Image
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Pre-processing layer: 100 -> 4*4*1024
        latent_dim = 100
        self.linear1 = nn.Linear(in_features=latent_dim, out_features=4*4*1024)
        # view reshape: 4*4*1024 -> 1024,4,4
        # Convolutional blocks
        self.model_blocks = nn.Sequential(
            # Block 1: 1024,4,4 -> 512,8,8
            nn.Upsample(scale_factor=2),                               # 1024,4,4 -> 1024,8,8
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),  # 1024,8,8 -> 512,8,8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # Block 2: 512,8,8 -> 256,16,16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # Block 3: 256,16,16 -> 128,32,32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # Block 4: 128,32,32 -> 3,64,64
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.Tanh()
        )

    # Forward pass
    def forward(self, z):
        z = self.linear1(z)
        z = z.view(z.shape[0], 1024, 4, 4)
        img = self.model_blocks(z)
        return img
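A quick way to verify the architecture is a shape check on a random latent batch. This is a minimal sketch (CPU, batch of 2) that only assumes the Generator class above; it should print a 2x3x64x64 tensor shape.

# Sanity check: feed a random latent batch through the generator
g = Generator()
z = torch.randn(2, 100)
print(g(z).shape)  # expected: torch.Size([2, 3, 64, 64])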
Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model_blocks = nn.Sequential(
            # Input: 3,64,64
            # Block 1: 3,64,64 -> 128,32,32
            nn.Conv2d(3, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # Block 2: 128,32,32 -> 256,16,16
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # Block 3: 256,16,16 -> 512,8,8
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # Block 4: 512,8,8 -> 1024,4,4
            nn.Conv2d(512, 1024, 3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5)
        )
        # Flatten: 1024,4,4 -> 4*4*1024
        # Output layer
        self.output = nn.Sequential(
            nn.Linear(in_features=4*4*1024, out_features=1),
            nn.Sigmoid()
        )

    # Forward pass
    def forward(self, x):
        """x : batch,channel,h,w -> 64,3,64,64"""
        y = self.model_blocks(x)
        y = y.view(x.shape[0], -1)
        y = self.output(y)
        return y
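The same kind of shape check works for the discriminator. The sketch below uses a random image batch and only assumes the Discriminator class above; the output should be one probability per image.

# Sanity check: feed a random image batch through the discriminator
d = Discriminator()
x = torch.randn(2, 3, 64, 64)
print(d(x).shape)  # expected: torch.Size([2, 1]), values in (0, 1) because of the Sigmoid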
Data loader
# Dataset for a flat folder of images
class ImgDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.transform = transform
        # Collect all image paths
        self.img_paths = []
        for filename in os.listdir(root_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.img_paths.append(os.path.join(root_dir, filename))
        # Report the total number of images
        print(f"All : {len(self.img_paths)}")

    def __len__(self):
        return len(self.img_paths)

    # Load a single image
    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        # Return the image and a dummy label; all images belong to one class
        return img, 0
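The dataset can be tried on its own before wiring it into a DataLoader. This usage sketch assumes a folder path (any directory of .png/.jpg/.jpeg files works) and mirrors the training transform without the normalization step.

# Usage sketch; the folder path is an assumption
preview_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
dataset = ImgDataset(root_dir='/opt/notebook/dcgan/images_resized', transform=preview_transform)
img, label = dataset[0]
print(img.shape, label)  # expected: torch.Size([3, 64, 64]) 0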
Training process
def get_model_instance(device, g_lr, d_lr, b1, b2):
    # Generator
    generator = Generator().to(device)
    # Discriminator
    discriminator = Discriminator().to(device)
    # Optimizers
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))
    # Loss function (BCE)
    criterion = torch.nn.BCELoss()
    return generator, discriminator, optimizer_g, optimizer_d, criterion


def get_data_loader(root_dir, batch_size):
    # Image transforms: the source images are 128*128 and are resized to 64*64
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Data loader
    dataset = ImgDataset(root_dir=root_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        prefetch_factor=4,
        persistent_workers=True,
        drop_last=True
    )
    return dataloader


def append_logs(data, path='dcgan_logs.txt'):
    with open(path, 'a', encoding='utf-8') as f:
        print(data)
        f.write(data + '\n')


def train_model(dataloader, discriminator, optimizer_d, generator, optimizer_g,
                criterion, epochs, device):
    append_logs(f"start : {datetime.now()}")
    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            # Actual batch size of this iteration
            in_batch_size = real_imgs.shape[0]
            # Move the real images to the device
            real_imgs = real_imgs.to(device)
            # Real/fake labels, smoothed to 0.8 / 0.2
            real_labels = torch.ones(in_batch_size, 1, device=device) * 0.8
            fake_labels = torch.ones(in_batch_size, 1, device=device) * 0.2

            # Train the discriminator
            optimizer_d.zero_grad()
            # Loss on real images
            d_loss_real = criterion(discriminator(real_imgs), real_labels)
            # Generate fake images (no generator gradients needed here)
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            with torch.no_grad():
                fake_imgs = generator(z)
            fake_imgs_detach = fake_imgs.detach()
            # Loss on fake images
            d_loss_fake = criterion(discriminator(fake_imgs_detach), fake_labels)
            # Total discriminator loss
            d_loss = (d_loss_real + d_loss_fake) / 2
            # Backpropagation and weight update
            d_loss.backward()
            optimizer_d.step()

            # Generate fake images again, this time with gradients
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            fake_imgs = generator(z)
            # Train the generator
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # Log every 100 batches
            if i % 100 == 0:
                append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")

        # Log at the end of each epoch
        append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # Every 2 epochs: write sample images and save checkpoints
        if epoch % 2 == 0:
            save_image(
                fake_imgs.data[:25],
                f"sample_{epoch}.png",
                nrow=5,
                normalize=True
            )
            # Save checkpoints
            torch.save(generator.state_dict(), f'generator{epoch}.pth')
            print(f'generator{epoch}.pth saved')
            torch.save(discriminator.state_dict(), f'discriminator{epoch}.pth')
            print(f'discriminator{epoch}.pth saved')
    append_logs(f"end : {datetime.now()}")
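Before launching a full run, the loop can be smoke-tested on random tensors. The sketch below (CPU, tiny batch, one epoch) is only meant to confirm that the wiring runs end to end and writes the usual logs, samples, and checkpoints; it does not produce meaningful images, and the batch/epoch numbers are arbitrary assumptions.

# Smoke test on random data, just to verify the training loop wiring
from torch.utils.data import TensorDataset

cpu = torch.device('cpu')
dummy = TensorDataset(torch.randn(8, 3, 64, 64), torch.zeros(8, dtype=torch.long))
tiny_loader = DataLoader(dummy, batch_size=4, shuffle=True, drop_last=True)

g, d, opt_g, opt_d, crit = get_model_instance(cpu, g_lr=0.0002, d_lr=0.0002, b1=0.5, b2=0.999)
train_model(tiny_loader, d, opt_d, g, opt_g, crit, epochs=1, device=cpu)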
Basic parameters
base_path = "/opt/notebook/dcgan/images_resized"
batch_size = 64
g_lr = 0.0002
d_lr = 0.0002
b1 = 0.5
b2 = 0.999
epochs = 50

device = torch.device('cuda:0')
print(device)

(generator, discriminator, optimizer_g, optimizer_d, criterion) = get_model_instance(device, g_lr, d_lr, b1, b2)
dataloader = get_data_loader(base_path, batch_size)

train_model(
    dataloader,
    discriminator, optimizer_d,
    generator, optimizer_g,
    criterion, epochs, device
)

# Save the final model
torch.save(generator.state_dict(), 'generator.pth')
print('generator.pth saved')
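To continue from an interrupted run rather than starting over, the per-epoch checkpoints written by train_model can be loaded before calling it again. In this sketch the epoch index 10 is only an assumption; both .pth files must already exist from an earlier run.

# Optional: resume from saved checkpoints before calling train_model again
generator.load_state_dict(torch.load('generator10.pth', map_location=device))
discriminator.load_state_dict(torch.load('discriminator10.pth', map_location=device))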
The generator and discriminator play an adversarial game, so the two losses rise and fall against each other and never settle into a stable value; what ultimately matters is the quality of the generated images.
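To judge that quality, one option is to reload the saved generator and render a grid of samples. The sketch below assumes the generator.pth file written above; the output file name is arbitrary.

# Sampling sketch: load the trained generator and write a 5x5 grid of samples
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
generator.load_state_dict(torch.load('generator.pth', map_location=device))
generator.eval()

with torch.no_grad():
    z = torch.randn(25, 100, device=device)
    samples = generator(z)
save_image(samples, 'samples_final.png', nrow=5, normalize=True)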