Anime Character Avatars

Created: 2025-03-10
Updated: 2025-03-10

Dataset download: https://www.kaggle.com/datasets/soumikrakshit/anime-faces
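
The training script below assumes the Kaggle archive has been unzipped into a flat folder of image files at ./data/Anime Faces (matching the anime_faces_folder path used in the code; adjust it if your layout differs). A quick hedged check that the images are in place:

import os

files = os.listdir("./data/Anime Faces")
print(f"Found {len(files)} files, e.g. {files[:3]}")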

Training Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os

num_epochs = 100
batch_size = 64
learning_rate = 0.0002
latent_dim = 100
img_channels = 3
feature_maps = 64

anime_faces_folder = "./data/Anime Faces"
generator_model_path = "./model/anime_faces/generator.pth"
discriminator_model_path = "./model/anime_faces/discriminator.pth"

os.makedirs("./model/anime_faces", exist_ok=True)

class AnimeFacesDataset(Dataset):
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform

        self.files = []
        self.labels = []

        for file in os.listdir(root):
            # Keep image files only; the label is a constant placeholder,
            # since GAN training here is unsupervised.
            if file.lower().endswith((".png", ".jpg", ".jpeg")):
                self.files.append(file)
                self.labels.append(1)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        image = Image.open(path).convert("RGB")
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
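
# Note: Normalize with mean 0.5 and std 0.5 per channel maps pixels from
# [0, 1] to [-1, 1], matching the Tanh output range of the generator below.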

dataset = AnimeFacesDataset(root=anime_faces_folder, transform=transform)
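
# Fail fast if the path is wrong or the folder is empty,
# rather than erroring later inside the DataLoader.
assert len(dataset) > 0, f"No images found in {anime_faces_folder}"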

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()
        self.net = nn.Sequential(
            # Input: (latent_dim, 1, 1) -> (512, 4, 4)
            nn.ConvTranspose2d(
                latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(),
            # (512, 4, 4) -> (256, 8, 8)
            nn.ConvTranspose2d(
                feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(),
            # (256, 8, 8) -> (128, 16, 16)
            nn.ConvTranspose2d(
                feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(),
            # (128, 16, 16) -> (3, 32, 32); Tanh squashes outputs to [-1, 1]
            nn.ConvTranspose2d(
                feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)
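
# Quick shape sanity check (a sketch, safe to delete): the four transposed
# convolutions upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32, so a latent
# batch (N, latent_dim, 1, 1) must come out as (N, img_channels, 32, 32).
with torch.no_grad():
    _probe = Generator(latent_dim, img_channels, feature_maps)(
        torch.randn(2, latent_dim, 1, 1)
    )
    assert _probe.shape == (2, img_channels, 32, 32)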

class Discriminator(nn.Module):
    def __init__(self, img_channels, feature_maps):
        super().__init__()
        self.net = nn.Sequential(
            # Input: (3, 32, 32)
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1),  # Output: (64, 16, 16)
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps, feature_maps * 2, 4, 2, 1),  # Output: (128, 8, 8)
            nn.BatchNorm2d(feature_maps * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps * 2, feature_maps * 4, 4, 2, 1),  # Output: (256, 4, 4)
            nn.BatchNorm2d(feature_maps * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(feature_maps * 4, 1, 4, 1, 0),  # Output: (1, 1, 1)
            nn.Sigmoid(),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.net(x)
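
# Mirror check for the discriminator: a (N, 3, 32, 32) image batch reduces to
# one probability per sample, shape (N, 1), thanks to the final Flatten.
with torch.no_grad():
    _score = Discriminator(img_channels, feature_maps)(
        torch.randn(2, img_channels, 32, 32)
    )
    assert _score.shape == (2, 1)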

# DCGAN-style initialization: conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02)
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

generator = (
    Generator(latent_dim, img_channels, feature_maps).apply(weights_init).to(device)
)
discriminator = Discriminator(img_channels, feature_maps).apply(weights_init).to(device)

if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path, map_location=device))

if os.path.exists(discriminator_model_path):
    discriminator.load_state_dict(
        torch.load(discriminator_model_path, map_location=device)
    )

# MSELoss on sigmoid outputs gives a least-squares (LSGAN-style) objective;
# the classic DCGAN choice here would be nn.BCELoss().
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(
    discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999)
)

for epoch in range(num_epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # Train the discriminator
        optimizer_D.zero_grad()

        # Real labels smoothed to 0.9 (one-sided label smoothing) instead of a hard 1
        real_labels = torch.full((batch_size, 1), 0.9, device=device)
        real_output = discriminator(real_imgs)
        d_loss_real = criterion(real_output, real_labels)

        # Generate fake images
        z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
        fake_imgs = generator(z)

        # Fake labels softened to 0.1 instead of a hard 0
        fake_labels = torch.full((batch_size, 1), 0.1, device=device)
        fake_output = discriminator(fake_imgs.detach())
        d_loss_fake = criterion(fake_output, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train the generator twice per discriminator update
        for _ in range(2):
            optimizer_G.zero_grad()
            # Re-sample noise and regenerate fakes for diversity
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            fake_imgs = generator(z)
            output = discriminator(fake_imgs)
            # The generator tries to make the discriminator output "real" (1)
            g_loss = criterion(output, torch.ones_like(output))
            g_loss.backward()
            optimizer_G.step()

        if i % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] Batch {i}/{len(dataloader)} "
                f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
            )

torch.save(generator.state_dict(), generator_model_path)
torch.save(discriminator.state_dict(), discriminator_model_path)

# Switch BatchNorm layers to inference mode before sampling
generator.eval()
# Sample random noise
z = torch.randn(1, latent_dim, 1, 1).to(device)
# Generate one image
generated_img = generator(z).detach().cpu()
# Denormalize from [-1, 1] back to [0, 1]
generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5
# Visualize
plt.imshow(generated_img)
plt.axis("off")
plt.show()
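
To keep a grid of samples on disk after a training run, here is a minimal sketch using torchvision.utils.save_image (the output filename samples.png is an assumption):

import torchvision.utils as vutils

# Denormalize 16 fresh samples from [-1, 1] back to [0, 1] and save as a 4x4 grid.
with torch.no_grad():
    sample_z = torch.randn(16, latent_dim, 1, 1, device=device)
    samples = generator(sample_z).cpu() * 0.5 + 0.5
vutils.save_image(samples, "samples.png", nrow=4)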

Generation Code

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os

latent_dim = 100
img_channels = 3
feature_maps = 64

generator_model_path = "./model/anime_faces/generator.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Generator(nn.Module):
    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()
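        # Identical architecture to the training script:
        # 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32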
        self.net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_dim, feature_maps * 8, kernel_size=4, stride=1, padding=0
            ),
            nn.BatchNorm2d(feature_maps * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 8, feature_maps * 4, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 4, feature_maps * 2, kernel_size=4, stride=2, padding=1
            ),
            nn.BatchNorm2d(feature_maps * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(
                feature_maps * 2, img_channels, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)

generator = Generator(latent_dim, img_channels, feature_maps).to(device)

if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path, map_location=device))

# Use running BatchNorm statistics for inference
generator.eval()

rows, cols = 4, 4

figure = plt.figure()

for row in range(rows):
    for col in range(cols):
        z = torch.randn(1, latent_dim, 1, 1).to(device)
        generated_img = generator(z).detach().cpu()
        generated_img = generated_img.squeeze().permute(1, 2, 0) * 0.5 + 0.5

        index = row * cols + col
        plot = figure.add_subplot(rows, cols, index + 1)
        plot.imshow(generated_img)
        plot.axis("off")

plt.show()

Sample Output

Figure_1.png: the 4x4 grid of generated faces produced by the generation script.