变换
变换组合
对数据进行预处理是炼丹中重要的一环,transforms.Compose 用于“打包”各种变换。
要理解该函数,需要先理解 Python 的“类”。
与 C++、Java 等语言不同(类是生成类实例的模板,类自身不是实例),Python 的类在行为上更类似于 Lua 的表,即类本身就是一个已经实例化的对象,要对这种特殊的对象进行操作,除了像定义成员方法并进行调用外,还有一种特殊的方式——元方法。
元方法从行为上与 C++ 的运算符重载非常类似。例如,你可以为对象定义 __call__ 方法,这相当于在 C++ 中重载 operator(),随后,该对象就成了可调用对象(Callable Objects)。
理解了这一点,就能够理解 PyTorch 中模块的概念。
以 ToTensor 方法为例:
class ToTensor:
def __call__(self, pic):
"""
Args:
pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
return F.to_tensor(pic)
定义了 __call__ 方法后,可以像调用一个普通的方法那样调用 ToTensor 模块:
ToTensor(pic)
与常规的方法不同的是,ToTensor 作为一个实例化的对象,其内部还能存储一些数据,定义一些辅助的成员方法,进而实现更简洁的模块设计,而不是将数据和流程(Procedure)全部塞进一个方法中。
回到 transforms.Compose 本身,该方法从根本上来说就是把传入其中的模块按序拼接在一起,前一个模块的输出作为后一个模块的输入,形成一个变换的流水线模型,例如,下面的例程将缩放、张量化、归一化操作三个模块拼接在一起,最后一个模块的输出作为预处理的结果。
training_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
数据集
定义了基本变换后,下一个步骤是加载数据集,并对数据执行预定义的变换。torchvision.datasets.ImageFolder 用于从文件夹中加载图像数据集,并执行预定义的预处理操作。图像数据集需要遵循如下的文件结构:
root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
需要指出,torchvision.datasets.ImageFolder 仍是一个类,因此