Unit Scaling是一种模型设计方法,能够让FP16和FP8等低精度数字格式更加易用。我们很高兴地宣布,我们发布了一个PyTorch库来促进Unit Scaling的使用。
7月,Graphcore(拟未)在ICML上发表了论文《Unit Scaling:开箱即用的低精度训练》。现在,我们发布软件工具,让更多人可以使用这种方法。
支持FP8的硬件的发展大幅提高了用户效率,例如Graphcore® C600 IPU处理器PCIe卡。但是,简单地将较高精度的值转换为FP8值往往会导致性能降低。Unit Scaling解决了这一问题,为充分利用FP8硬件进行训练提供了一条简单的路径。
查看库文档
阅读我们的ICML论文
库操作演示
为了向用户展示如何在他们自己的模型中应用Unit Scaling,我们还发布了一个和库配套的notebook。它展示了在FP8下,使用和不使用Unit Scaling的nanoGPT模型训练。
只需一行代码——model = unit_scale(model)——用户就可以将他们的PyTorch模块转化为单位缩放模型。我们在notebook中通过训练以下模型说明了这一点:
from nanoGPT.model import GPT
from notebook_utils import config, train
from unit_scaling.transforms import simulate_fp8, unit_scale
gpt = GPT(config) # model unchanged from original nanoGPT
fp8_gpt = simulate_fp8(gpt)
unit_scaled_fp8_gpt = unit_scale(fp8_gpt)
models = [gpt, fp8_gpt, unit_scaled_fp8_gpt]
for model in models:
train(model)

在FP8中直接训练基础模型会导致明显的性能下降。不过,使用Unit Scaling可以恢复全精度。
这种单行转换可应用于任意PyTorch模型,与torch.compile一起使用时,开销可忽略不计。
实施Unit Scaling
单行自动unit_scale()转换是一项实验性功能。我们建议大多数用户通过以下方式手动实施Unit Scaling。
您可以考虑这种导入PyTorch模块/函数的常见方法:
from torch import nn
from torch.nn import functional as F
在此设置中,您可以这样实现Unit Scaling,首先添加:
import unit_scaling as uu
from unit_scaling import functional as U
然后将字母nn替换为uu,将F替换为U,让这样这些类和函数变成单位缩放的。例如:
class UnitScaledMLP(nn.Module):
def _init_(self, d:int) -> None:
super()._init_()
self.linear_1 = uu.Linear(d, d*4) # Changed 'nn' to 'uu'
self.linear_2 = uu.Linear(d*4, d) # Changed 'nn' to 'uu'
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x)
x = U.gelu(x) # Changed 'F' to 'U'
return self.linear_2(x)
我们的用户指南中还介绍了运用Unit Scaling所需的一些其他注意事项。用户应特别注意正确缩放跳转/残差增量和损失函数。
使用库
Unit Scaling可通过以下方式安装:
pip install git+https://github.com/grapchore-research/unit-scaling.git
尽管我们付出了诸多努力,但unit_scaling毕竟是一个新的库,我们不能保证它完全没有bug或功能已经完全齐备。不过我们将非常愿意为这个库的使用者提供帮助。