torchvision.transforms使用详解
torchvision.transforms
1. torchvision.transforms.Compose
torchvision.transforms.Compose:是一个用于组合多个图像预处理操作的类,将多个预处理操作串联在一起,以便在数据加载时对图像进行连续的处理。
torchvision.transforms.Compose(transforms) 的参数是一个列表,其中包含要进行的图像预处理操作。
1 | import torchvision.transforms as transforms |
2.torchvision.transforms.RandomResizedCrop
torchvision.transforms.RandomResizedCrop:用于数据增广,以增加数据集的多样性和提高模型的泛化能力(随机裁剪图像,并将裁剪后的图像调整为指定的大小)
torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(3.0/4.0, 4.0/3.0), interpolation=2)
的参数包括:
size
:输出的裁剪后图像的大小,可以是一个整数或一个元组 (height, width)。scale
:控制裁剪区域相对于原始图像大小的尺度范围,它是一个长度为2的元组 (min_scale, max_scale)。裁剪区域的大小在 [min_scale 图像大小, max_scale 图像大小] 之间随机选择。ratio
:控制裁剪区域的宽高比范围,它是一个长度为2的元组 (min_ratio, max_ratio)。裁剪区域的宽高比在 [min_ratio, max_ratio] 之间随机选择。interpolation
:插值方法,用于调整裁剪后图像的大小。默认值为2,表示使用双线性插值。注:在图像处理中,将图像从一个尺寸调整为另一个尺寸时,通常需要使用插值方法来计算新尺寸的像素值。
1 | import torchvision.transforms as transforms |
3.torchvision.transforms.RandomHorizontalFlip()
torchvision.transforms.RandomHorizontalFlip():对图像进行随机水平翻转,增加数据集的多样性和提高模型的鲁棒性。
torchvision.transforms.RandomHorizontalFlip(p=0.5)
的参数 p
(默认为0.5)控制水平翻转的概率。
当 p=0.5
时,有50%的概率对图像进行水平翻转。当 p=0
时,不进行翻转;当 p=1
时,100%进行翻转。
1 | import torchvision.transforms as transforms |
4.torchvision.transforms.ToTensor()
torchvision.transforms.ToTensor():用于将 PIL 图像或 NumPy 数组转换为 PyTorch 张量(Tensor)格式,以便在深度学习模型中使用。
1 | import torchvision.transforms as transforms |
5.torchvision.transforms.Normalize
torchvision.transforms.Normalize:对图像进行归一化处理,以便模型在训练和推断过程中更好地处理数据
5.1 计算数据集的均值和标准差
os.walk(dataset_path):用于遍历文件夹的一个函数,它生成一个三元组的迭代器,每次迭代返回一个包含当前目录路径、当前目录下所有子目录名、当前目录下所有文件名的元组。
1 | import numpy as np |
注:
1.表示图像的 Tensor,其维度为 (C, H, W),其中 C 表示通道数,H 表示图像的高度,W 表示图像的宽度。dim=(1, 2) 表示在第 1 和第 2 维度上进行求均值的操作,torch.mean(img, dim=(1, 2))将对每个通道的高度和宽度上的所有像素值进行求均值。
2.PIL 图像(NumPy 数组表示):在将图像转换为 NumPy 数组时,通常图像中的像素值被映射到 [0, 255] 范围内,方便对图像进行基本的像素级操作,如颜色调整、滤波等。
3.PyTorch 张量:PyTorch 张量在处理图像时,通常会进行数据归一化处理,图像的像素值会被映射到 [0, 1] 范围内,这可以有效地缩小不同通道之间数值的差异,避免数据在训练过程中产生较大的梯度,导致训练不稳定。(归一化方式是将像素值除以 255)
5.2 torchvision.transforms.Normalize
torchvision.transforms.Normalize(mean, std)
的参数包括:
mean
:一个包含三个元素的列表或元组,表示每个通道的均值。对于 RGB 图像,通常是 [R 均值, G 均值, B 均值]。std
:一个包含三个元素的列表或元组,表示每个通道的标准差。对于 RGB 图像,通常是 [R 标准差, G 标准差, B 标准差]
1 | import torchvision.transforms as transforms |
6.torchvision.transforms.functional.crop
torchvision.transforms.functional.crop:从输入图像中裁剪出指定区域的子图像
1 | torchvision.transforms.functional.crop(img, top, left, height, width) |
img
:输入的图像,通常是一个 PIL 图像或一个张量。top
:裁剪区域的顶部边界(以像素为单位)。left
:裁剪区域的左边界(以像素为单位)。height
:裁剪区域的高度(以像素为单位)。width
:裁剪区域的宽度(以像素为单位)。
1 | crop_rect = (0, 0, 320, 480) |