PyTorch 2.0:更快、更Pythonic、更动态的下一代深度学习框架

PyTorch 是一个开源的机器学习框架,它可以让研究人员和开发者快速地从原型设计到生产部署。PyTorch 的最大优势之一是它提供了一个简单、灵活、命令式的编程风格,让用户可以像写 Python 代码一样写深度学习代码,并且能够实时地调试和优化模型。

近日,PyTorch 团队宣布了 PyTorch 2.0 的发布,这是 PyTorch 的下一代版本,它在保持原有的 eager 模式和用户体验的基础上,从根本上改变了 PyTorch 在编译器级别的运行方式。PyTorch 2.0 能够为动态形状(Dynamic Shapes)和分布式运行提供更快的性能和更好的支持。此外,PyTorch 2.0 还包括了一些新的功能和改进,例如 Accelerated Transformers(加速版变换器)、torch.compile(编译 API)、torch.func(函数式 API)等等。

在本文中,我们将简要介绍 PyTorch 2.0 的主要特点和意义,并给出一些示例代码来展示如何使用 PyTorch 2.0 来构建和优化深度学习模型。

动态形状(Dynamic Shapes)

动态形状是指在运行时可以改变张量(Tensor)的大小或维度的能力。这对于处理不同长度或格式的数据非常有用,例如文本、语音、图像等。动态形状也可以让模型更加灵活和通用,例如可以根据输入数据的特征来调整网络结构或参数。

PyTorch 原本就支持动态形状,在 eager 模式下,用户可以随时修改张量或模块(Module)的属性,并且立即看到效果。然而,在某些情况下,为了提高性能或兼容性,用户可能需要将模型转换为静态图(Static Graph)或导出为其他格式(例如 ONNX)。这时候就需要对模型进行跟踪(Tracing)或脚本化(Scripting),将 Python 代码转换为 TorchScript 代码,并生成一个可序列化(Serializable)和可优化(Optimizable)的计算图。

跟踪或脚本化过程中会损失动态形状的信息,因为 TorchScript 需要知道每个操作所涉及的张量大小或维度。如果输入数据发生变化,则需要重新跟踪或脚本化模型,并且可能导致性能下降或错误发生。

为了解决这个问题,PyTorch 2.0 引入了 torch.compile 这个新功能。torch.compile 可以将任何 PyTorch 模型包装成一个编译后的模型,并返回一个与原始模型具有相同接口和行为的对象。编译后的模型可以自动适应不同大小或维度的输入张量,而无需重新编译或跟踪。编译后的模型可以利用 TorchDynamo、AOTAutograd、PrimTorch 和 TorchInductor 这些新技术来进行优化和加速。

torch.compile 的使用非常简单,只需要将模型或函数作为参数传入即可。例如,我们可以定义一个简单的函数 foo,它接受两个张量 x 和 y 作为输入,并返回它们的正弦和余弦之和。然后我们可以用 torch.compile 来包装这个函数,并得到一个编译后的函数 opt_foo。我们可以像调用原始函数一样调用编译后的函数,并且得到相同的结果,但是速度会更快。

import torch

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b

opt_foo = torch.compile(foo)

print(opt_foo(torch.randn(10, 10), torch.randn(10, 10)))

torch.compile 还支持一些其他的参数,例如 fullgraph、dynamic、backend、mode、options 等,可以用来控制编译过程中的细节和选项3。例如,我们可以指定 backend 为 ‘inductor’ 来使用 TorchInductor 编译器,它可以生成针对不同加速器和后端的高效代码。我们还可以指定 options 为 {‘matmul-padding’: True} 来开启矩阵乘法(MatMul)的填充优化。

@torch.compile(backend='inductor', options={'matmul-padding': True})
def opt_foo(x, y):
    a = torch.sin(x)
    b = torch.cos(x)
    return a + b

print(opt_foo(torch.randn(10, 10), torch.randn(10, 10)))

加速版变换器(Accelerated Transformers)

变换器(Transformer)是一种流行的深度学习模型,它主要用于自然语言处理(NLP)等领域,具有强大的表示能力和并行性。变换器的核心是自注意力(Self-Attention)机制,它可以让模型捕捉输入序列中的长距离依赖关系,并且可以高效地利用 GPU 或其他加速器。

然而,自注意力机制也有一些缺点,例如计算复杂度和内存消耗随着序列长度的增加而呈平方级增长,这限制了模型可以处理的最大序列长度。为了解决这个问题,PyTorch 2.0 包含了一个稳定版本的加速版变换器(Accelerated Transformers),它使用了一些优化技术来提高变换器的性能和效率。

加速版变换器的前身称作 Better Transformer,它使用了自定义的内核架构,在训练和推理中使用缩放点积注意力(Scaled Dot Product Attention,SPDA)方法,以获得更好的性能。SPDA 是一种基于矩阵乘法和 softmax 的自注意力方法,它可以充分利用 GPU 的并行计算能力,并且可以通过 torch.compile 来进一步优化。

PyTorch 2.0 提供了一个新的 scaled_dot_product_attention 函数作为 torch.nn.functional 的一部分。这个函数包含了几种不同的实现方式,可以根据输入数据和硬件设备来选择适合的方法。例如,FlashAttention 是一种节省内存的算法,它可以在不影响精度的情况下减少中间结果所占用的空间。Memory-Efficient Attention 是另一种节省内存的算法,它使用了块状矩阵乘法和重复利用梯度等技巧来降低内存消耗。

用户可以通过调用 scaled_dot_product_attention 函数来直接使用这些优化方法,也可以通过 torch.compile 来自动选择最佳方法。例如,我们可以定义一个简单的变换器层 TransformerLayer,并使用 scaled_dot_product_attention 函数来实现自注意力机制。

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerLayer(nn.Module):
    def __init__(self, d_model, nhead):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (seq_len, batch_size, d_model)
        attn_output, _ = F.scaled_dot_product_attention(
            self.self_attn.in_proj_q(x),
            self.self_attn.in_proj_k(x),
            self.self_attn.in_proj_v(x),
            key_padding_mask=None,
            need_weights=False,
            attn_mask=None,
            dropout_p=self.self_attn.dropout,
        )
        # attn_output: (seq_len, batch_size, d_model)
        x = x + self.self_attn.dropout(attn_output)
        x = self.norm1(x)

        y = F.relu(self.linear1(x))
        y = self.linear2(y)
        x = x + y
        x = self.norm2(x)

        return x

函数式 API(torch.func)

torch.func 是 PyTorch 2.0 中的一个新模块,它提供了一些类似于 JAX 的可组合函数变换(Composable Function Transforms)。函数变换是一种高阶函数,它接受一个数值函数作为输入,并返回一个计算不同量的新函数。torch.func 包含了一些自动微分变换(例如 grad、vjp、jvp 等),一个向量化/批处理变换(vmap),以及其他一些有用的变换。

这些函数变换可以任意地组合在一起,从而实现一些在 PyTorch 中难以做到的用例,例如计算每个样本的梯度(Per-Sample-Gradients)、运行模型集成、批量处理 MAML 的内循环、高效地计算雅可比矩阵和海森矩阵等等。

torch.func 的使用也非常简单,只需要将目标函数作为参数传入相应的函数变换即可。例如,我们可以定义一个简单的损失函数 loss,并使用 grad 函数来得到一个计算损失函数梯度的新函数 grad_loss。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import grad

model = nn.Linear(3, 3)
x = torch.randn(4, 3)
t = torch.randn(4, 3)

def loss(params):
    y = F.linear(x, params['weight'], params['bias'])
    return F.mse_loss(y, t)

grad_loss = grad(loss)

params = dict(model.named_parameters())
grads = grad_loss(params)
print(grads['weight'])
print(grads['bias'])

torch.func 还支持对模块进行函数式调用,即不改变模块本身的参数和缓冲区,而是使用给定的参数和缓冲区来执行模块。这可以通过 functional_call 函数来实现。例如,我们可以定义一个简单的线性层 LinearLayer,并使用 functional_call 来用不同的权重和偏置来调用它。

import torch
import torch.nn as nn
from torch.func import functional_call

class LinearLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.randn(out_features))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)

layer = LinearLayer(3, 3)
x = torch.randn(4, 3)

params1 = dict(layer.named_parameters())
y1 = layer(x) # equivalent to functional_call(layer, params1, x)

params2 = {'weight': torch.ones(3, 3), 'bias': torch.zeros(3)}
y2 = functional_call(layer, params2, x) # use different parameters

print(y1)
print(y2)

结论

PyTorch 2.0 是 PyTorch 的下一代版本,它在保持原有的 eager 模式和用户体验的基础上,从根本上改变了 PyTorch 在编译器级别的运行方式。PyTorch 2.0 能够为动态形状和分布式运行提供更快的性能和更好的支持。此外,PyTorch 2.0 还包括了一些新的功能和改进,例如加速版变换器、torch.compile、torch.func 等等。

这些新特性可以让 PyTorch 用户更方便地构建和优化深度学习模型,并且可以应对各种复杂的场景和需求。PyTorch 2.0 是一个值得期待的版本,我们希望它能够为 PyTorch 社区带来更多的创新和价值。

本文由作者 王大神 原创发布于 大神网的AI博客。

转载请注明作者:王大神

原文出处:PyTorch 2.0:更快、更Pythonic、更动态的下一代深度学习框架

(0)
打赏 微信扫一扫 微信扫一扫
上一篇 2023年3月19日
下一篇 2023年3月19日