Dataset download: https://www.kaggle.com/datasets/soumikrakshit/anime-faces
Training code
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
# ---------------------------------------------------------------------------
# Hyperparameters and paths for DCGAN training on the Anime Faces dataset.
# ---------------------------------------------------------------------------
num_epochs = 100          # total passes over the dataset
batch_size = 64           # images per training batch
learning_rate = 0.0002    # Adam learning rate (DCGAN paper value)
latent_dim = 100          # length of the generator's input noise vector
img_channels = 3          # RGB images
feature_maps = 64         # base channel width for both networks
anime_faces_folder = "./data/Anime Faces"
generator_model_path = "./model/anime_faces/generator.pth"
discriminator_model_path = "./model/anime_faces/discriminator.pth"

# exist_ok=True avoids the check-then-create race of the original
# `if not os.path.exists(...): os.makedirs(...)` pattern.
os.makedirs("./model/anime_faces", exist_ok=True)
class AnimeFacesDataset(Dataset):
    """Flat-folder image dataset: every image file directly under *root*
    becomes one ``(image, 1)`` sample.

    The constant label 1 marks samples as "real" for GAN training; fake
    samples come from the generator at train time.
    """

    # Only files with these extensions are loaded. Filtering skips stray
    # non-image files (e.g. ".DS_Store", "README.txt") that would otherwise
    # crash Image.open() in __getitem__.
    IMG_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")

    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        # sorted() gives a deterministic sample order — os.listdir() order
        # is filesystem-dependent and arbitrary.
        self.files = sorted(
            f for f in os.listdir(root)
            if f.lower().endswith(self.IMG_EXTENSIONS)
        )
        self.labels = [1] * len(self.files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        path = os.path.join(self.root, self.files[index])
        image = Image.open(path).convert("RGB")  # force 3-channel RGB
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label
# Preprocessing / augmentation pipeline: resize to 32x32, random flips and
# colour jitter for augmentation, then scale pixel values into [-1, 1] so
# they match the generator's Tanh output range.
_pipeline_steps = [
    transforms.Resize((32, 32)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_pipeline_steps)

dataset = AnimeFacesDataset(root=anime_faces_folder, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Prefer GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
    """DCGAN generator: maps a latent tensor of shape (N, latent_dim, 1, 1)
    to an RGB image tensor of shape (N, img_channels, 32, 32) in [-1, 1]."""

    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()

        def up_block(in_ch, out_ch, stride, padding):
            # Transposed conv grows the spatial size (1x1 -> 4x4 when
            # stride=1/padding=0, else doubles it), followed by
            # BatchNorm + ReLU as in the DCGAN architecture.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4,
                                   stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        layers = []
        layers += up_block(latent_dim, feature_maps * 8, 1, 0)        # 1x1  -> 4x4
        layers += up_block(feature_maps * 8, feature_maps * 4, 2, 1)  # 4x4  -> 8x8
        layers += up_block(feature_maps * 4, feature_maps * 2, 2, 1)  # 8x8  -> 16x16
        layers += [
            nn.ConvTranspose2d(feature_maps * 2, img_channels,
                               kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
            nn.Tanh(),  # squash pixels into [-1, 1]
        ]
        # Module order matches the original Sequential exactly, so
        # state_dict keys (net.0, net.1, ...) stay checkpoint-compatible.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Discriminator(nn.Module):
    """DCGAN discriminator: scores 32x32 RGB images with a per-sample
    realness probability, returned as a (N, 1) tensor in (0, 1)."""

    def __init__(self, img_channels, feature_maps):
        super().__init__()

        def down_block(in_ch, out_ch):
            # Strided conv halves the spatial size; BatchNorm + LeakyReLU
            # per the DCGAN architecture.
            return [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]

        layers = [
            # input: (img_channels, 32, 32); no BatchNorm on the first layer
            nn.Conv2d(img_channels, feature_maps, 4, 2, 1),  # -> (fm, 16, 16)
            nn.LeakyReLU(0.2),
        ]
        layers += down_block(feature_maps, feature_maps * 2)      # -> (fm*2, 8, 8)
        layers += down_block(feature_maps * 2, feature_maps * 4)  # -> (fm*4, 4, 4)
        layers += [
            nn.Conv2d(feature_maps * 4, 1, 4, 1, 0),  # -> (1, 1, 1)
            nn.Sigmoid(),                             # probability of "real"
            nn.Flatten(),                             # (N, 1, 1, 1) -> (N, 1)
        ]
        # Same module order as the original Sequential, so state_dict keys
        # stay checkpoint-compatible.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
def weights_init(m):
    """DCGAN weight initialisation, intended for use with ``Module.apply``.

    Conv layers (Conv2d and ConvTranspose2d) get weights ~ N(0, 0.02);
    BatchNorm layers get weights ~ N(1, 0.02) and zero bias. Other module
    types are left untouched.
    """
    layer_type = type(m).__name__
    if "Conv" in layer_type:  # matches Conv2d and ConvTranspose2d alike
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
# Build both networks, apply DCGAN weight initialisation, move to device.
generator = Generator(latent_dim, img_channels, feature_maps)
generator.apply(weights_init)
generator = generator.to(device)

discriminator = Discriminator(img_channels, feature_maps)
discriminator.apply(weights_init)
discriminator = discriminator.to(device)

# Resume from checkpoints when they exist on disk.
if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path))
if os.path.exists(discriminator_model_path):
    discriminator.load_state_dict(torch.load(discriminator_model_path))

# MSE on the sigmoid scores (LSGAN-style objective, used here with
# smoothed labels 0.9 / 0.1 in the training loop).
criterion = nn.MSELoss()

# Adam with beta1 = 0.5 as recommended for DCGAN training.
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
for epoch in range(num_epochs):
    for step, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        n = real_imgs.size(0)  # the last batch may be smaller than batch_size

        # --- Discriminator update ------------------------------------------
        optimizer_D.zero_grad()

        # Real images target 0.9 rather than 1.0 (one-sided label smoothing
        # helps stabilise GAN training).
        real_targets = torch.full((n, 1), 0.9, device=device)
        d_loss_real = criterion(discriminator(real_imgs), real_targets)

        # Fake images from the generator; detach() blocks gradients from
        # flowing into the generator during the discriminator step.
        z = torch.randn(n, latent_dim, 1, 1).to(device)
        fake_imgs = generator(z)
        fake_targets = torch.full((n, 1), 0.1, device=device)  # smoothed "fake"
        d_loss_fake = criterion(discriminator(fake_imgs.detach()), fake_targets)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # --- Generator update: twice per discriminator step ----------------
        for _ in range(2):
            optimizer_G.zero_grad()
            # Fresh noise on every pass for sample diversity.
            z = torch.randn(n, latent_dim, 1, 1).to(device)
            fake_imgs = generator(z)
            scores = discriminator(fake_imgs)
            # The generator's goal: make the discriminator call fakes real.
            g_loss = criterion(scores, torch.ones_like(scores))
            g_loss.backward()
            optimizer_G.step()

        if step % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] Batch {step}/{len(dataloader)} "
                f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
            )

    # Checkpoint after every epoch so training can be resumed later.
    torch.save(generator.state_dict(), generator_model_path)
    torch.save(discriminator.state_dict(), discriminator_model_path)
# Draw one random latent vector and render the generated face.
noise = torch.randn(1, latent_dim, 1, 1).to(device)
sample = generator(noise).detach().cpu()

# Undo the Normalize((0.5,)*3, (0.5,)*3) transform ([-1, 1] -> [0, 1]) and
# reorder channels from (C, H, W) to (H, W, C) for matplotlib.
sample = sample.squeeze().permute(1, 2, 0) * 0.5 + 0.5

plt.imshow(sample)
plt.axis("off")
plt.show()
Generation code
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import os
# Generator hyperparameters — these must match the values used at training
# time or the checkpoint's state_dict will not load.
latent_dim = 100      # noise vector length
img_channels = 3      # RGB output
feature_maps = 64     # base channel width
generator_model_path = "./model/anime_faces/generator.pth"

# Prefer GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Generator(nn.Module):
    """DCGAN generator (mirror of the training-script definition): maps a
    latent tensor (N, latent_dim, 1, 1) to an image (N, img_channels, 32, 32)
    with values in [-1, 1]."""

    def __init__(self, latent_dim, img_channels, feature_maps):
        super().__init__()

        def up_block(in_ch, out_ch, stride, padding):
            # Transposed conv upsamples; BatchNorm + ReLU per DCGAN.
            return [
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4,
                                   stride=stride, padding=padding),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
            ]

        stages = []
        stages += up_block(latent_dim, feature_maps * 8, 1, 0)        # 1x1  -> 4x4
        stages += up_block(feature_maps * 8, feature_maps * 4, 2, 1)  # 4x4  -> 8x8
        stages += up_block(feature_maps * 4, feature_maps * 2, 2, 1)  # 8x8  -> 16x16
        stages += [
            nn.ConvTranspose2d(feature_maps * 2, img_channels,
                               kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
            nn.Tanh(),  # squash pixels into [-1, 1]
        ]
        # Positional Sequential keeps state_dict keys (net.0, net.1, ...)
        # identical to the training checkpoint.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
generator = Generator(latent_dim, img_channels, feature_maps).to(device)
# Load trained weights if a checkpoint exists; otherwise the grid below
# shows output from randomly initialised weights.
if os.path.exists(generator_model_path):
    generator.load_state_dict(torch.load(generator_model_path))

# Render a 4x4 grid of independently sampled faces.
rows, cols = 4, 4
figure = plt.figure()
for cell in range(rows * cols):
    noise = torch.randn(1, latent_dim, 1, 1).to(device)
    img = generator(noise).detach().cpu()
    # Undo normalisation ([-1, 1] -> [0, 1]); (C, H, W) -> (H, W, C).
    img = img.squeeze().permute(1, 2, 0) * 0.5 + 0.5
    ax = figure.add_subplot(rows, cols, cell + 1)
    ax.imshow(img)
    ax.axis("off")
plt.show()