查看猫狗分类数据集中的一张图片
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Load one training image. A forward-slash relative path works on both
# Windows and POSIX, whereas a backslash-separated path only works on Windows.
# The context manager closes the file; convert() returns an independent image.
with Image.open("./data/Dogs Vs Cats/train/cat.16.jpg") as img:
    image = img.convert("RGB")
# ToTensor yields a float tensor laid out as [C, H, W] with values in [0, 1].
tensor = transforms.ToTensor()(image)
# Convert layout from [C, H, W] to [H, W, C], the layout matplotlib's imshow expects.
image = tensor.numpy().transpose((1, 2, 0))
# Scale back to 0-255 and round before casting: astype(np.uint8) truncates,
# so float round-trip error such as 254.99998 would otherwise become 254.
image = np.clip(image * 255, 0, 255).round().astype(np.uint8)
plt.title("Cat")
plt.imshow(image)
plt.axis("off")
plt.show()
效果图

通过自定义 Dataset 查看猫狗分类数据集中已转换为 Tensor 的图片
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
# Random train/test split over positions in the training folder's file list.
# Assumes the folder holds at least `train_size` files — TODO confirm; a
# smaller folder makes DogsVsCatsDataset raise IndexError when indexing.
train_size = 25000
num_train = 20000
# Seeded generator so the split is reproducible: an unseeded randperm puts a
# different set of images in the test split on every run.
split_seed = 0
indices = torch.randperm(
    train_size, generator=torch.Generator().manual_seed(split_seed)
)
train_indices = indices[:num_train]
test_indices = indices[num_train:]
class DogsVsCatsDataset(Dataset):
    """Cats-vs-dogs image dataset backed by a flat folder of image files.

    A file whose name contains "dog" is labelled 0; every other file is
    labelled 1. ``classes[label]`` maps a label back to its class name.

    Args:
        root: Folder containing the image files.
        train: Chooses between the module-level ``train_indices`` (True) and
            ``test_indices`` (False) when ``indices`` is not given.
        transform: Optional callable applied to each loaded PIL image.
        indices: Optional iterable of integer positions (plain ints or
            integer tensor elements) into the sorted file listing of
            ``root``. Defaults to the module-level split chosen by ``train``.
    """

    def __init__(self, root, train=True, transform=None, indices=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.classes = ["dog", "cat"]
        # os.listdir returns entries in arbitrary order; sort so every
        # instance (train and test) indexes the same list, otherwise the
        # shared index split is not guaranteed to be disjoint.
        files = sorted(os.listdir(root))
        if indices is None:
            indices = train_indices if train else test_indices
        # int() accepts plain ints as well as the 0-d tensors produced when
        # iterating a torch index tensor.
        self.files = [files[int(i)] for i in indices]
        self.labels = [0 if "dog" in name else 1 for name in self.files]

    def __len__(self):
        """Return the number of samples selected by the index split."""
        return len(self.files)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        The image is loaded as RGB and passed through ``transform`` when one
        was supplied.
        """
        path = os.path.join(self.root, self.files[index])
        # The context manager closes the file handle; convert() returns a
        # new image that does not depend on the open file.
        with Image.open(path) as img:
            image = img.convert("RGB")
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
# Preprocessing pipeline: resize every image to a fixed 224x224, then
# convert it to a tensor.
transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)
# Build the training split. os.path.join keeps the path portable instead of
# hard-coding Windows backslash separators.
train_dataset = DogsVsCatsDataset(
    root=os.path.join(".", "data", "Dogs Vs Cats", "train"),
    train=True,
    transform=transform,
)
image, label = train_dataset[0]
# `image` is a [C, H, W] tensor; move the channel axis last ([H, W, C]) for imshow.
image = image.numpy().transpose((1, 2, 0))
# Round before casting: astype(np.uint8) truncates, so float round-trip error
# such as 254.99998 would otherwise become 254.
image = np.clip(image * 255, 0, 255).round().astype(np.uint8)
plt.title(train_dataset.classes[label])
plt.imshow(image)
plt.axis("off")
plt.show()
效果图
