Viewing Images in a Dataset

Created: 2024-06-21
Updated: 2025-02-10

Viewing a single image

from torchvision import datasets, transforms
import matplotlib.pyplot as plt

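# Download FashionMNIST (if not already in ./data) and convert each image
# to a float tensor of shape (1, 28, 28) with values in [0, 1]
train_dataset = datasets.FashionMNIST(
    root="./data", train=True, transform=transforms.ToTensor(), download=True
)

# Each sample is an (image, label) pair; label is an integer class index
image, label = train_dataset[0]

plt.title(train_dataset.classes[label])  # map the class index to its name
plt.axis("off")
# squeeze() drops the channel dimension: (1, 28, 28) -> (28, 28)
plt.imshow(image.squeeze(), cmap="gray")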
plt.show()

Result:

[Screenshot: the first training image, shown in grayscale with its class name as the title]
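The single-image code applies two conversions before plotting: `image.squeeze()` removes the channel dimension that `transforms.ToTensor()` adds (the tensor has shape `(1, 28, 28)`, which `plt.imshow` does not accept for a grayscale image), and `train_dataset.classes[label]` turns the integer label into a readable name. The following is a minimal sketch that isolates both steps on a random tensor, so it runs without downloading the dataset. `to_displayable` and `FASHION_MNIST_CLASSES` are hypothetical names; the class list is written out to match the label order of torchvision's `FashionMNIST.classes`.

```python
import torch

# Class names in label order (index 0..9), mirroring torchvision's FashionMNIST.classes
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def to_displayable(sample, classes=FASHION_MNIST_CLASSES):
    """Hypothetical helper: turn an (image, label) sample into imshow-ready data.

    image: tensor of shape (1, H, W), as produced by transforms.ToTensor()
    label: integer class index
    Returns the (H, W) image tensor and the class name to use as a title.
    """
    image, label = sample
    if image.dim() != 3 or image.shape[0] != 1:
        raise ValueError(f"expected a (1, H, W) grayscale tensor, got {tuple(image.shape)}")
    # Remove only the channel axis (dim 0), leaving height and width intact
    return image.squeeze(0), classes[label]


# Synthetic stand-in for train_dataset[0]; no download needed
fake_sample = (torch.rand(1, 28, 28), 9)
gray, title = to_displayable(fake_sample)
print(tuple(gray.shape), title)  # (28, 28) Ankle boot
```

The sketch uses `squeeze(0)` rather than a bare `squeeze()` so that only the channel axis is removed; a bare `squeeze()` would also drop any other axis that happens to have size 1.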

Viewing multiple images

from torchvision import datasets, transforms
import matplotlib.pyplot as plt

train_dataset = datasets.FashionMNIST(
    root="./data", train=True, transform=transforms.ToTensor(), download=True
)

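figure = plt.figure()

rows, cols = 3, 3

for row in range(rows):
    for col in range(cols):
        # Row-major position in the grid, also used as the dataset index
        index = row * cols + col
        image, label = train_dataset[index]

        # add_subplot positions are 1-based, so shift the 0-based index by one
        figure.add_subplot(rows, cols, index + 1)
        plt.title(train_dataset.classes[label])
        plt.axis("off")
        plt.imshow(image.squeeze(), cmap="gray")

plt.tight_layout()  # keep subplot titles from overlapping
plt.show()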

Result:

[Screenshot: the first nine training images in a 3x3 grid, each titled with its class name]
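The grid code ties the dataset index to the subplot position, so it always shows the first `rows * cols` images. To browse other samples, the two can be decoupled: keep the 1-based, row-major subplot numbering (`row * cols + col + 1`) and draw the dataset indices at random. The following is a minimal sketch under that assumption; `random_grid` is a hypothetical helper, and a seeded `random.Random` is used so the same selection can be reproduced.

```python
import random


def random_grid(rows, cols, dataset_size, seed=None):
    """Hypothetical helper: pick rows * cols distinct dataset indices at random
    and pair each one with its 1-based subplot position.

    Positions follow matplotlib's row-major subplot numbering
    (left to right, then top to bottom), i.e. row * cols + col + 1.
    """
    rng = random.Random(seed)  # seeded RNG makes the selection reproducible
    indices = rng.sample(range(dataset_size), rows * cols)  # no repeated images
    layout = []
    for row in range(rows):
        for col in range(cols):
            position = row * cols + col + 1
            layout.append((position, indices[position - 1]))
    return layout


# 60,000 is the size of the FashionMNIST training split; real code would pass len(train_dataset)
for position, index in random_grid(3, 3, 60_000, seed=0):
    print(position, index)
```

Inside the plotting loop, `image, label = train_dataset[index]` and `figure.add_subplot(rows, cols, position)` would then take the place of the original `index` arithmetic.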