查看一张图片
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Build the FashionMNIST training split rooted at ./data (download=True),
# converting every sample with transforms.ToTensor().
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Take the first (image, label) pair and look up the label's class name.
image, label = train_dataset[0]
class_name = train_dataset.classes[label]

# squeeze() removes size-1 dimensions so imshow receives a plain 2-D array
# (presumably this drops the leading channel axis -- confirm against the
# dataset's tensor shape).
pixels = image.squeeze()

# Draw the sample in grayscale with the class name as title and no axes.
plt.title(class_name)
plt.axis("off")
plt.imshow(pixels, cmap="gray")
plt.show()
效果图

查看多张图片
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Build the FashionMNIST training split rooted at ./data (download=True),
# converting every sample with transforms.ToTensor().
train_dataset = datasets.FashionMNIST(
    root="./data", train=True, transform=transforms.ToTensor(), download=True
)

# Show the first rows * cols samples in a grid, one subplot per sample.
figure = plt.figure()
rows, cols = 3, 3
for index in range(rows * cols):
    image, label = train_dataset[index]
    # Subplot positions are 1-based and filled row by row, hence index + 1.
    axes = figure.add_subplot(rows, cols, index + 1)
    # Drive the returned Axes directly instead of pyplot's implicit
    # "current axes", so each call clearly targets this subplot.
    axes.set_title(train_dataset.classes[label])
    axes.axis("off")
    # squeeze() removes size-1 dimensions so imshow receives a 2-D array
    # (presumably this drops the leading channel axis -- confirm).
    axes.imshow(image.squeeze(), cmap="gray")

# Display the whole grid once, after every subplot has been drawn.
plt.show()
效果图
