Cite: https://github.com/vahidk/EffectivePyTorch

N.B.

  • PyTorch ≈ NumPy
  • vectorized implementation

有效的 PyTorch

目录

  • PyTorch 基础
  • 用预定义模块封装模型
  • 广播的优点与缺点
  • 利用重载操作符
  • 用 TorchScript 优化 Runtime
  • 构建高效的自定义数据加载器
  • PyTorch 中的数值稳定性
  • 用自动混合精度加快训练

安装 PyTorch,请遵循官方网站上的说明:

pip install torch torchvision

PyTorch 基础

PyTorch 是进行数值计算的最流行库之一,目前也是用于执行机器学习研究的最广泛使用的库。在许多方面,PyTorch 与 NumPy 相似,额外的好处是 PyTorch 允许你在 CPU、GPU 和 TPU 上执行计算,而不需要对代码进行任何实质性更改。PyTorch 还使跨多个设备或机器分发计算变得容易。PyTorch 最重要的特性之一是自动微分。它允许以高效的方式计算函数的梯度,这对于使用梯度下降法训练机器学习模型至关重要。我们的目标是为 PyTorch 提供一个温和的介绍,并讨论使用 PyTorch 的最佳实践。

关于 PyTorch 首先需要了解的是张量的概念。张量简单地说是多维数组。PyTorch 张量与 NumPy 数组非常相似,具有一些额外的功能。

一个张量可以存储一个标量值:

import torch
a = torch.tensor(3)
print(a)  # tensor(3)

或者一个数组:

b = torch.tensor([1, 2])
print(b)  # tensor([1, 2])

一个矩阵:

c = torch.zeros([2, 2])
print(c)  # tensor([[0., 0.], [0., 0.]])

或者任何任意维度的张量:

d = torch.rand([2, 2, 2])

张量可以用来高效地执行代数运算。在机器学习应用中最常用的操作之一是矩阵乘法。假设你想乘以两个大小为 3x5 和 5x4 的随机矩阵,可以使用矩阵乘法@操作:

import torch

x = torch.randn([3, 5])
y = torch.randn([5, 4])
z = x @ y

print(z)

同样,要添加两个向量,可以这样做:

z = x + y

要将一个张量转换为 Numpy 数组,可以调用张量的 numpy() 方法:

print(z.numpy())

而且你可以随时将 Numpy 数组转换为张量:

x = torch.tensor(np.random.normal([3, 5]))

自动微分

PyTorch 相对于 NumPy 的最重要优势是其自动微分功能,这在优化应用中非常有用,例如优化神经网络的参数。让我们通过一个例子来理解它。

假设你有一个复合函数,它是两个函数的链:g(u(x))。 要计算 g 关于 x 的导数,我们可以使用链式法则,它指出:dg/dx = dg/du * du/dx。PyTorch 可以为我们分析计算导数。

要在 PyTorch 中计算导数,首先我们创建一个张量并将其 requires_grad 设置为 true。我们可以使用张量操作来定义我们的函数。我们假设 u 是一个二次函数,g 是一个简单的线性函数:

x = torch.tensor(1.0, requires_grad=True)
def u(x):
  return x * x
def g(u):
  return -u

在这种情况下,我们的复合函数是 g(u(x)) = -x*x。所以它关于 x 的导数是 -2x。在 x=1 这一点上,这等于 -2。

让我们验证一下。这可以使用 PyTorch 中的 grad 函数完成:

dgdx = torch.autograd.grad(g(u(x)), x)[0]
print(dgdx)  # tensor(-2.)

曲线拟合

为了理解自动微分的强大之处,让我们来看另一个例子。假设我们有一些来自曲线的样本(比如说 f(x) = 5x^2 + 3),我们想基于这些样本估计 f(x)。我们定义了一个参数化函数 g(x, w) = w0 x^2 + w1 x + w2,这是一个关于输入 x 和潜在参数 w 的函数,我们的目标是找到这样的潜在参数,使得 g(x, w) ≈ f(x)。这可以通过最小化以下损失函数来实现:L(w) = Σ (f(x) - g(x, w))^2。尽管这个简单问题有一个封闭形式的解,但我们选择使用一种更通用的方法,可以应用于任何任意可微函数,那就是使用随机梯度下降。我们只需计算 L(w) 关于 w 在一组样本点上的平均梯度,并朝相反方向移动。

以下是如何在 PyTorch 中完成的:

import numpy as np
import torch

# Assuming we know that the desired function is a polynomial of 2nd degree, we
# allocate a vector of size 3 to hold the coefficients and initialize it with
# random noise.
w = torch.tensor(torch.randn([3, 1]), requires_grad=True)

# We use the Adam optimizer with learning rate set to 0.1 to minimize the loss.
opt = torch.optim.Adam([w], 0.1)

def model(x):
    # We define yhat to be our estimate of y.
    f = torch.stack([x * x, x, torch.ones_like(x)], 1)
    yhat = torch.squeeze(f @ w, 1)
    return yhat

def compute_loss(y, yhat):
    # The loss is defined to be the mean squared error distance between our
    # estimate of y and its true value. 
    loss = torch.nn.functional.mse_loss(yhat, y)
    return loss

def generate_data():
    # Generate some training data based on the true function
    x = torch.rand(100) * 20 - 10
    y = 5 * x * x + 3
    return x, y

def train_step():
    x, y = generate_data()

    yhat = model(x)
    loss = compute_loss(y, yhat)

    opt.zero_grad()
    loss.backward()
    opt.step()

for _ in range(1000):
    train_step()

print(w.detach().numpy())

运行这段代码,你应该看到一个接近以下结果:

[4.9924135, 0.00040895029, 3.4504161]

这是我们参数的一个相对接近的近似值。

这只是 PyTorch 功能的冰山一角。许多问题,如优化具有数百万参数的大型神经网络,可以在 PyTorch 中仅用几行代码高效实现。PyTorch 负责跨多个设备、线程的扩展,并支持多种平台。

用预定义模块封装模型

在前一个例子中,我们使用原始的张量和张量操作来构建我们的模型。为了使你的代码稍微更有组织,建议使用 PyTorch 的 modules。一个模块只是你的参数的容器,并封装了模型操作。例如,假设你想表示一个线性模型 y = ax + b。这个模型可以用以下代码表示:

import torch

class Net(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.a = torch.nn.Parameter(torch.rand(1))
    self.b = torch.nn.Parameter(torch.rand(1))

  def forward(self, x):
    yhat = self.a * x + self.b
    return yhat

要在实践中使用这个模型,你实例化模块并像调用函数一样调用它:

x = torch.arange(100, dtype=torch.float32)

net = Net()
y = net(x)

参数本质上是将 requires_grad 设置为 True 的张量。使用参数非常方便,因为你可以通过模块的 parameters() 方法简单地检索它们所有:

for p in net.parameters():
  print(p)

现在,假设你有一个未知函数 y = 5x + 3 + noise,你想优化你的模型的参数以适应这个函数。你可以从你的函数中采样一些点开始:

x = torch.arange(100, dtype=torch.float32) / 100
y = 5 * x + 3 + torch.rand(100) * 0.3

与前一个例子类似,你可以定义一个损失函数并优化你的模型的参数,如下所示:

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

for i in range(10000):
  net.zero_grad()
  yhat = net(x)
  loss = criterion(yhat, y)
  loss.backward()
  optimizer.step()

print(net.a, net.b) # 应该接近 5 和 3

PyTorch 配备了许多预定义的模块。其中一个模块是 torch.nn.Linear,它比我们上面定义的线性函数更通用。我们可以用 torch.nn.Linear 重写我们的模块,如下所示:

class Net(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(1, 1)

  def forward(self, x):
    yhat = self.linear(x.unsqueeze(1)).squeeze(1)
    return yhat

请注意,我们使用了 squeezeunsqueeze,因为 torch.nn.Linear 在向量批次上操作,而不是标量。

默认情况下,在模块上调用 parameters() 将返回其所有子模块的参数:

net = Net()
for p in net.parameters():
  print(p)

有一些预定义的模块充当其他模块的容器。最常用的容器模块是 torch.nn.Sequential。顾名思义,它用于将多个模块(或层)堆叠在一起。例如,要在两个 Linear 层之间堆叠一个 ReLU 非线性,可以这样做:

model = torch.nn.Sequential(
  torch.nn.Linear(64, 32),
  torch.nn.ReLU(),
  torch.nn.Linear(32, 10),
)

广播的优点与缺点

PyTorch 支持广播元素级操作。通常,当你想执行像加法和乘法这样的操作时,你需要确保操作数的形状匹配,例如你不能将形状为 [3, 2] 的张量加到形状为 [3, 4] 的张量上。但是有一个特殊情况,那就是当你有一个单一维度时。PyTorch 隐式地在单一维度上平铺张量以匹配另一个操作数的形状。所以将形状为 [3, 2] 的张量加到形状为 [3, 1] 的张量上是有效的。

import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1.], [2.]])
# c = a + b.repeat([1, 2])
c = a + b

print(c)

广播允许我们进行隐式平铺,这使得代码更短,更节省内存,因为我们不需要存储平铺操作的结果。这种技术的一个很好的应用是在组合长度不同的特征时。为了连接长度不同的特征,我们通常平铺输入张量,连接结果并应用一些非线性。这是各种神经网络架构中的常见模式:

a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])

linear = torch.nn.Linear(11, 10)

# concat a and b and apply nonlinearity
tiled_b = b.repeat([1, 3, 1])
c = torch.cat([a, tiled_b], 2)
d = torch.nn.functional.relu(linear(c))

print(d.shape)  # torch.Size([5, 3, 10])

但这可以通过广播更有效地完成。我们利用 f(m(x + y)) 等于 f(mx + my) 的事实。所以我们可以分别进行线性操作,并使用广播进行隐式连接:

a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])

linear1 = torch.nn.Linear(5, 10)
linear2 = torch.nn.Linear(6, 10)

pa = linear1(a)
pb = linear2(b)
d = torch.nn.functional.relu(pa + pb)

print(d.shape)  # torch.Size([5, 3, 10])

事实上,这段代码非常通用,只要张量之间的广播是可能的,就可以应用于任意形状的张量:

class Merge(torch.nn.Module):
    def __init__(self, in_features1, in_features2, out_features, activation=None):
        super().__init__()
        self.linear1 = torch.nn.Linear(in_features1, out_features)
        self.linear2 = torch.nn.Linear(in_features2, out_features)
        self.activation = activation

    def forward(self, x1, x2):
        pa = self.linear1(x1)
        pb = self.linear2(x2)
        d = self.activation(pa + pb)
        return d

到目前为止,我们讨论了广播的优点。但你可能要问,它的缺点是什么?隐式假设几乎总是使调试变得更加困难。考虑以下例子:

a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b)

print(c)

你认为 c 的值在评估后会是多少?如果你猜是 6,那是错误的。它会是 12。这是因为当两个张量的秩不匹配时,PyTorch 会自动扩展低秩张量的第一维,然后再进行元素级操作,所以加法的结果将是 [[2, 3], [3, 4]],并且对所有参数进行减少会给我们 12。

避免这个问题的方法是尽可能明确。如果我们指定了我们想要减少的维度,捕捉这个错误会容易得多:

a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b, 0)

print(c)

这里 c 的值将是 [5, 7],我们立即会根据结果的形状猜测有些问题。一个通用的经验法则是在 reduction 操作中始终指定维度,并在使用 torch.squeeze 时这样做。

利用重载操作符

就像 NumPy 一样,PyTorch 重载了许多 Python 操作符,使 PyTorch 代码更短,更易读。

切片操作是重载的操作符之一,可以使索引张量变得非常容易:

z = x[begin:end]  # z = torch.narrow(0, begin, end-begin)

但使用这个操作时要小心。切片操作像任何其他操作一样有一些开销。因为它是一个常见的操作,看起来无辜,可能会被过度使用,这可能导致低效。为了理解这个操作可能有多低效,让我们来看一个例子。我们想手动对一个矩阵的行进行减少:

import torch
import time

x = torch.rand([500, 10])

z = torch.zeros([10])

start = time.time()
for i in range(500):
    z += x[i]
print("Took %f seconds." % (time.time() - start))

这运行相当慢,原因是我们调用了切片操作 500 次,增加了很多开销。一个更好的选择是使用 torch.unbind 操作一次将矩阵切片成一系列向量:

z = torch.zeros([10])
for x_i in torch.unbind(x):
    z += x_i

这显著地(在我的机器上大约快 30%)更快。

当然,做这个简单的减少的正确方法是使用 torch.sum 操作一次完成:

z = torch.sum(x, dim=0)

这非常快(在我的机器上大约快 100 倍)。

PyTorch 还重载了一系列算术和逻辑操作符:

z = -x  # z = torch.neg(x)
z = x + y  # z = torch.add(x, y)
z = x - y
z = x * y  # z = torch.mul(x, y)
z = x / y  # z = torch.div(x, y)
z = x // y
z = x % y
z = x ** y  # z = torch.pow(x, y)
z = x @ y  # z = torch.matmul(x, y)
z = x > y
z = x >= y
z = x < y
z = x <= y
z = abs(x)  # z = torch.abs(x)
z = x & y
z = x | y
z = x ^ y  # z = torch.logical_xor(x, y)
z = ~x  # z = torch.logical_not(x)
z = x == y  # z = torch.eq(x, y)
z = x != y  # z = torch.ne(x, y)

你也可以使用这些操作的增强版本。例如 x += yx **= 2 也是有效的。

注意 Python 不允许重载 and, or, 和 not 关键词。

用 TorchScript 优化 Runtime

PyTorch 优化了在大张量上执行操作。在小张量上执行许多操作在 PyTorch 中相当低效。因此,只要可能,你应该将你的计算重写为批处理形式,以减少开销并提高性能。如果你无法手动批处理你的操作,使用 TorchScript 可能会提高你的代码性能。TorchScript 只是 PyTorch 识别的 Python 函数的子集。PyTorch 可以使用其即时编译器 jit 自动优化你的 TorchScript 代码,并减少一些开销。

让我们来看一个例子。在 ML 应用中一个非常常见的操作是 批处理收集。这个操作可以简单地写成 output[i] = input[i, index[i]]。这可以在 PyTorch 中实现如下:

import torch

def batch_gather(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)

要使用 TorchScript 实现相同的功能,只需使用 torch.jit.script 装饰器:

@torch.jit.script
def batch_gather_jit(tensor, indices):
    output = []
    for i in range(tensor.size(0)):
        output += [tensor[i][indices[i]]]
    return torch.stack(output)

在我的测试中,这大约快了 10%。

但没有什么比手动批处理你的操作更好的了。在我的测试中,向量化的实现快了 100 倍:

def batch_gather_vec(tensor, indices):
    shape = list(tensor.shape)
    flat_first = torch.reshape(
        tensor, [shape[0] * shape[1]] + shape[2:])
    offset = torch.reshape(
        torch.arange(shape[0]).cuda() * shape[1], [shape[0]] + [1] * (len(indices.shape) - 1))
    output = flat_first[indices + offset]
    return output

构建高效的自定义数据加载器

在上一堂课中,我们讨论了编写高效的 PyTorch 代码。但为了使你的代码以最大效率运行,你还需要有效地将你的数据加载到设备内存中。幸运的是,PyTorch 提供了一个工具来使数据加载变得容易。它被称为 DataLoaderDataLoader 使用多个工作线程同时从 Dataset 加载数据,并可选地使用 Sampler 来采样数据条目并形成批次。

如果你可以随机访问你的数据,使用 DataLoader 非常简单:你只需要实现一个 Dataset 类,该类实现 __getitem__(读取每个数据项)和 __len__(返回数据集中的项目数)方法。例如,以下是如何从给定目录加载图像的示例:

import glob
import os
import random
import cv2
import torch

class ImageDirectoryDataset(torch.utils.data.Dataset):
    def __init__(path, pattern):
        self.paths = list(glob.glob(os.path.join(path, pattern)))

    def __len__(self):
        return len(self.paths)

    def __item__(self):
        path = random.choice(paths)
        return cv2.imread(path, 1)

然后,你可以这样做来加载给定目录中的所有 jpeg 图像:

dataloader = torch.utils.data.DataLoader(ImageDirectoryDataset("/data/imagenet/*.jpg"), num_workers=8)
for data in dataloader:
    # do something with data

在这里,我们使用 8 个工作线程同时从磁盘读取我们的数据。你可以在你的机器上调整工作线程的数量以获得最佳结果。

如果你的数据项很大,使用具有随机访问的数据加载器可能没问题。但想象一下,如果你有一个网络文件系统连接速度很慢。以这种方式请求单个文件可能会非常慢,并且可能最终成为你训练管道的瓶颈。

一个更好的方法是将你的数据存储在连续的文件格式中,可以顺序读取。例如,如果你有大量图像,可以使用 tar 创建一个单一的存档,并在 python 中顺序提取存档中的文件。为此,你可以使用 PyTorch 的 IterableDataset。要创建一个 IterableDataset 类,你只需要实现一个 __iter__ 方法,该方法按顺序从数据集中读取并产生数据项。

一个简单的实现可能如下所示:

import tarfile
import torch

def tar_image_iterator(path):
    tar = tarfile.open(self.path, "r")
    for tar_info in tar:
        file = tar.extractfile(tar_info)
        content = file.read()
        yield cv2.imdecode(content, 1)
        file.close()
        tar.members = []
    tar.close()

class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, path):
        super().__init__()
        self.path = path

    def __iter__(self):
        yield from tar_image_iterator(self.path)

但这个实现有一个主要问题。如果你试图使用 DataLoader 以多于一个工作线程的方式从这个数据集中读取,你会观察到很多重复的图像:

dataloader = torch.utils.data.DataLoader(TarImageDataset("/data/imagenet.tar"), num_workers=8)
for data in dataloader:
    # data contains duplicated items

问题是每个工作线程创建了一个单独的数据集实例,每个实例都会从数据集的开头开始。避免这个问题的一种方法是,不要有一个 tar 文件,而是将你的数据分成 num_workers 个单独的 tar 文件,并让每个工作线程加载每个文件:

class TarImageDataset(torch.utils.data.IterableDataset):
    def __init__(self, paths):
        super().__init__()
        self.paths = paths

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        # For simplicity we assume num_workers is equal to number of tar files
        if worker_info is None or worker_info.num_workers != len(self.paths):
            raise ValueError("Number of workers doesn't match number of files.")
        yield from tar_image_iterator(self.paths[worker_info.worker_id])

这是我们的 dataset 类可以这样使用:

dataloader = torch.utils.data.DataLoader(
    TarImageDataset(["/data/imagenet_part1.tar", "/data/imagenet_part2.tar"]), num_workers=2)
for data in dataloader:
    # do something with data

我们讨论了一个简单的策略来避免重复条目问题。tfrecord 包使用稍微更复杂的策略来即时分片你的数据。

PyTorch 中的数值稳定性

当你使用任何数值计算库,如 NumPy 或 PyTorch 时,重要的是要注意,编写数学上正确的代码并不一定导致正确的结果。你还需要确保计算是稳定的。

让我们从一个简单的例子开始。数学上,很容易看出 x * y / y = x 对于任何非零的 x 值都是正确的。但让我们看看这在实践中是否总是正确的:

import numpy as np

x = np.float32(1)

y = np.float32(1e-50)  # y would be stored as zero
z = x * y / y

print(z)  # prints nan

不正确的原因是 y 对于 float32 类型来说太小了。当 y 太大时,也会出现类似的问题:

y = np.float32(1e39)  # y 将被存储为 inf
z = x * y / y

print(z)  # 打印 nan

float32 类型能表示的最小正值是 1.4013e-45,任何低于这个值的东西都会被存储为零。此外,任何超过 3.40282e+38 的数字,都会被存储为 inf。

print(np.nextafter(np.float32(0), np.float32(1)))  # prints 1.4013e-45
print(np.finfo(np.float32).max)  # print 3.40282e+38

为了确保你的计算是稳定的,你希望避免具有很小或非常大绝对值的值。这听起来非常明显,但这些问题可以变得非常难以调试,尤其是在 PyTorch 中进行梯度下降时。这是因为你不仅需要确保前向传递中的所有值都在你的数据类型的有效范围内,而且还需要确保反向传递(梯度计算期间)也是如此。

让我们来看一个真实的例子。我们想要计算一个 logits 向量的 softmax。一个简单的实现可能看起来像这样:

import torch

def unstable_softmax(logits):
    exp = torch.exp(logits)
    return exp / torch.sum(exp)

print(unstable_softmax(torch.tensor([1000., 0.])).numpy())  # 打印 [nan, 0.]

请注意,对于相对较小的数字计算 logits 的指数会导致巨大的结果,这些结果超出了 float32 范围。我们简单的 softmax 实现的最大有效 logit 是 ln(3.40282e+38) = 88.7`,任何超过这个值的东西都会导致 nan 结果。

但我们如何使其更稳定呢?解决方案相对简单。很容易看出 exp(x - c) Σ exp(x - c) = exp(x) / Σ exp(x)。因此,我们可以从 logits 中减去任何常数,结果将保持不变。我们选择这个常数为 logits 的最大值。这样,指数函数的域将被限制在 [-inf, 0],因此其范围将是 [0.0, 1.0],这是可取的:

import torch

def softmax(logits):
    exp = torch.exp(logits - torch.reduce_max(logits))
    return exp / torch.sum(exp)

print(softmax(torch.tensor([1000., 0.])).numpy())  # 打印 [1., 0.]

让我们来看一个更复杂的情况。考虑我们有一个分类问题。我们使用 softmax 函数从我们的 logits 生成概率。然后我们将我们的损失函数定义为预测和标签之间的交叉熵。回想一下,对于一个分类分布,交叉熵可以简单地定义为 xe(p, q) = -Σ p_i log(q_i)。因此,交叉熵的简单实现看起来像这样:

def unstable_softmax_cross_entropy(labels, logits):
    logits = torch.log(softmax(logits))
    return -torch.sum(labels * logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = unstable_softmax_cross_entropy(labels, logits)

print(xe.numpy())  # 打印 inf

请注意,在这个实现中,随着 softmax 输出接近零,对数的输出接近无穷大,这导致我们的计算不稳定。我们可以通过展开 softmax 并进行一些简化来重写这个实现:

def softmax_cross_entropy(labels, logits, dim=-1):
    scaled_logits = logits - torch.max(logits)
    normalized_logits = scaled_logits - torch.logsumexp(scaled_logits, dim)
    return -torch.sum(labels * normalized_logits)

labels = torch.tensor([0.5, 0.5])
logits = torch.tensor([1000., 0.])

xe = softmax_cross_entropy(labels, logits)

print(xe.numpy())  # 打印 500.0

我们还可以验证梯度也被正确计算:

logits.requires_grad_(True)
xe = softmax_cross_entropy(labels, logits)
g = torch.autograd.grad(xe, logits)[0]
print(g.numpy())  # 打印 [0.5, -0.5]

让我再次提醒,当你进行梯度下降时,必须格外小心,以确保每一层的函数范围以及梯度都在有效范围内。指数和对数函数在天真地使用时尤其成问题,因为它们可以将小数字映射到巨大的数字,反之亦然。

用混合精度加快训练

默认情况下,PyTorch 中的张量和模型参数以 32 位浮点精度存储。使用 32 位浮点数训练神经网络通常是稳定的,不会引起重大数值问题,然而神经网络在 16 位甚至更低精度下也表现出色。在现代 GPU 上,使用更低精度进行计算可以显著更快。它还有额外的好处是使用更少的内存,使得训练更大的模型和/或更大的批次大小成为可能,这可以进一步提高性能。但问题在于,16 位训练往往变得非常不稳定,因为精度通常不足以执行一些操作,如累积。

为了解决这个问题,PyTorch 支持混合精度训练。简而言之,混合精度训练是通过执行一些昂贵的操作(如卷积和矩阵乘法)在 16 位中进行,同时执行其他数值敏感的操作(如累积)在 32 位中进行。这样我们就能在没有 16 位计算的缺点的情况下获得 16 位计算的所有好处。接下来我们将讨论使用 AutocastGradScaler 进行自动混合精度训练。

Autocast

autocast 通过自动将数据转换为 16 位来提高 Runtime 性能,以进行某些计算。让我们通过一个例子来了解它的工作原理:

import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda()

with torch.cuda.amp.autocast():
  a = x + y
  b = x @ y

print(a.dtype)  # 打印 torch.float32
print(b.dtype)  # 打印 torch.float16

请注意,x 和 y 都是 32 位张量,但 autocast 执行矩阵乘法时使用 16 位,同时保持加法操作在 32 位。如果其中一个操作数是 16 位的呢?

import torch

x = torch.rand([32, 32]).cuda()
y = torch.rand([32, 32]).cuda().half()

with torch.cuda.amp.autocast():
  a = x + y
  b = x @ y

print(a.dtype)  # 打印 torch.float32
print(b.dtype)  # 打印 torch.float16

再次,autocast 将 32 位操作数转换为 16 位以执行矩阵乘法,但它不会改变加法操作。默认情况下,PyTorch 中两个张量的相加操作会导致转换为更高的精度。

在实践中,你可以信任 autocast 进行正确的转换以提高运行效率。重要的是要将所有前向传递计算放在 autocast 上下文中:

model = ...
loss_fn = ...
with torch.cuda.amp.autocast():
  outputs = model(inputs)
  loss = loss_fn(outputs, targets)

这可能是你所需要的,如果你有一个相对稳定的优化问题,并且使用相对较低的学习率。添加这行额外的代码可以在现代硬件上将训练时间缩短一半。

GradScaler

正如我们在本节开头提到的,16 位精度可能并不总是足以进行某些计算。一个特别有趣的情况是表示梯度值,这些值中的大部分通常都是小值。用 16 位浮点数表示它们通常会导致缓冲区下溢(即它们将被表示为零)。这使得训练神经网络非常不稳定。GradScaler 旨在解决这个问题。它将损失值作为输入,并将其乘以一个大的标量,从而膨胀梯度值,因此它们可以在 16 位精度中表示。然后,在梯度更新期间将它们缩小,以确保参数被正确更新。这通常是 GradScaler 所做的。但在底层,GradScaler 比这更聪明。膨胀梯度可能会导致溢出,这同样糟糕。因此,GradScaler 实际上监控梯度值,如果检测到溢出,它会跳过更新,根据可配置的计划缩小标量因子。(默认计划通常有效,但你可能需要根据你的用例进行调整。)

在实践中,使用 GradScaler 非常简单:

scaler = torch.cuda.amp.GradScaler()

loss = ...
optimizer = ...  # torch.optim.Optimizer 的一个实例

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

请注意,我们首先创建了一个 GradScaler 实例。在训练循环中,我们调用 GradScaler.scale 来缩放损失,然后再调用 backward 产生膨胀的梯度,然后使用 GradScaler.step(可能会)更新模型参数。然后我们调用 GradScaler.update 进行标量更新(如果需要)。就这样!

以下是展示在合成问题上进行混合精度训练的示例代码,学习从图像坐标生成棋盘格。你可以将它粘贴到 Google Colab 上,将后端设置为 GPU,并比较单精度和混合精度性能。请注意,这是一个小玩具示例,在实践中,使用混合精度训练更大的网络可能会看到更大的性能提升。

生成一个棋盘格

import torch
import matplotlib.pyplot as plt
import time

def grid(width, height):
  hrange = torch.arange(width).unsqueeze(0).repeat([height, 1]).div(width)
  vrange = torch.arange(height).unsqueeze(1).repeat([1, width]).div(height)
  output = torch.stack([hrange, vrange], 0)
  return output


def checker(width, height, freq):
  hrange = torch.arange(width).reshape([1, width]).mul(freq / width / 2.0).fmod(1.0).gt(0.5)
  vrange = torch.arange(height).reshape([height, 1]).mul(freq / height / 2.0).fmod(1.0).gt(0.5)
  output = hrange.logical_xor(vrange).float()
  return output

# Note the inputs are grid coordinates and the target is a checkerboard
inputs = grid(512, 512).unsqueeze(0).cuda()
targets = checker(512, 512, 8).unsqueeze(0).unsqueeze(1).cuda()

定义一个卷积神经网络

class Net(torch.jit.ScriptModule):
  def __init__(self):
    super().__init__()
    self.net = torch.nn.Sequential(
      torch.nn.Conv2d(2, 256, 1),
      torch.nn.BatchNorm2d(256),
      torch.nn.ReLU(),
      torch.nn.Conv2d(256, 256, 1),
      torch.nn.BatchNorm2d(256),
      torch.nn.ReLU(),
      torch.nn.Conv2d(256, 256, 1),
      torch.nn.BatchNorm2d(256),
      torch.nn.ReLU(),
      torch.nn.Conv2d(256, 1, 1)
    )

  @torch.jit.script_method
  def forward(self, x):
    return self.net(x)

单精度训练

net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)

start_time = time.time()

for i in range(500):
  opt.zero_grad()
  outputs = net(inputs)
  loss = loss_fn(outputs, targets)
  loss.backward()
  opt.step()

print(loss)

print(time.time() - start_time)

plt.subplot(1,2,1); plt.imshow(outputs.squeeze().detach().cpu());
plt.subplot(1,2,2); plt.imshow(targets.squeeze().cpu()); plt.show()

混合精度训练

net = Net().cuda()
loss_fn = torch.nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), 0.001)

scaler = torch.cuda.amp.GradScaler()

start_time = time.time()

for i in range(500):
  opt.zero_grad()
  with torch.cuda.amp.autocast():
    outputs = net(inputs)
    loss = loss_fn(outputs, targets)
  scaler.scale(loss).backward()
  scaler.step(opt)
  scaler.update()

print(loss)

print(time.time() - start_time)

plt.subplot(1,2,1); plt.imshow(outputs.squeeze().detach().cpu().float());
plt.subplot(1,2,2); plt.imshow(targets.squeeze().cpu().float()); plt.show()