使用PyTorch查看猫狗分类数据集中的图片

创建日期:2025-02-07
更新日期:2025-02-10

查看猫狗分类数据集中的一张图片

from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Load one training image and force 3-channel RGB.
# Forward slashes work on both Windows and POSIX, unlike "\\" separators.
with Image.open("./data/Dogs Vs Cats/train/cat.16.jpg") as img:
    image = img.convert("RGB")

# ToTensor: PIL image -> float tensor laid out as [C, H, W].
tensor = transforms.ToTensor()(image)

# Reorder the axes from [C, H, W] to [H, W, C], the layout imshow expects.
image = tensor.numpy().transpose((1, 2, 0))

# Scale back to 0-255 and convert to uint8. Round before casting: astype()
# truncates, so a product that lands just below an integer (e.g. 28.999998)
# would otherwise drop one intensity level.
image = (image * 255).round().astype(np.uint8)

plt.title("Cat")
plt.imshow(image)
plt.axis("off")
plt.show()

效果图

微信截图_20250210201154.png

通过自定义 Dataset 查看猫狗分类数据集中以 Tensor 形式返回的图片

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os

# Shuffle the positions 0..train_size-1 once. The first 20000 shuffled
# positions form the training split; the remaining ones form the test split.
train_size = 25000
indices = torch.randperm(train_size)
train_indices, test_indices = indices[:20000], indices[20000:]

class DogsVsCatsDataset(Dataset):
    """Dogs-vs-Cats dataset that reads image files from one flat directory.

    Each sample's label comes from its filename: names containing ``"dog"``
    get label 0, all others get label 1. These are the positions of
    ``"dog"`` and ``"cat"`` in ``self.classes``.

    Args:
        root: Directory containing the image files.
        train: Used only when ``indices`` is None. It selects the
            module-level ``train_indices`` (True) or ``test_indices``
            (False).
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        indices: Optional iterable of integer positions (ints or 0-d integer
            tensors) into the sorted directory listing. When given, it
            replaces the ``train``-based selection.

    Raises:
        IndexError: If an index is >= the number of directory entries.
    """

    def __init__(self, root, train=True, transform=None, indices=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]

        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting makes the index -> file mapping deterministic, so the same
        # permutation selects the same files on every machine and every call.
        files = sorted(os.listdir(root))

        if indices is None:
            indices = train_indices if train else test_indices

        # int() accepts both plain ints and 0-d tensors yielded by iterating
        # a 1-D index tensor.
        self.files = [files[int(i)] for i in indices]
        # Label 0 for dog, 1 for cat -- positions in self.classes.
        self.labels = [0 if "dog" in file else 1 for file in self.files]

    def __len__(self):
        """Return the number of selected image files."""
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        ``image`` is an RGB PIL image, or the result of ``self.transform``
        applied to it when a transform was given.
        """
        path = os.path.join(self.root, self.files[index])
        # The context manager closes the file once the RGB copy has been made.
        with Image.open(path) as img:
            image = img.convert("RGB")
        label = self.labels[index]
        if self.transform:
            image = self.transform(image)
        return image, label

# Resize every image to 224x224, then convert it to a [C, H, W] float tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Forward slashes keep the path portable across Windows and POSIX.
train_dataset = DogsVsCatsDataset(
    root="./data/Dogs Vs Cats/train", train=True, transform=transform
)

# First sample of the training split (which files are in the split depends
# on the random permutation drawn when the indices were created).
image, label = train_dataset[0]

# Reorder the axes from [C, H, W] to [H, W, C] for matplotlib.
image = image.numpy().transpose((1, 2, 0))

# Scale back to 0-255 and convert to uint8. Round before casting so float
# error just below an integer is not truncated down by one level.
image = (image * 255).round().astype(np.uint8)

plt.title(train_dataset.classes[label])
plt.imshow(image)
plt.axis("off")
plt.show()

效果图

微信截图_20250210201315.png