At today's Wave Summit+ 2021 Deep Learning Developer Summit, Graphcore and PaddlePaddle officially announced that the adaptation of the IPU to the PaddlePaddle framework is complete. PaddlePaddle now supports Graphcore's IPU processors, and users can select IPU hardware in the PaddlePaddle framework for AI training or inference. Graphcore officially joined the PaddlePaddle hardware ecosystem in May 2020 and is an important partner of PaddlePaddle for training and inference in the cloud.
Why support PaddlePaddle
PaddlePaddle is an industrial-grade deep learning framework open-sourced by Baidu, with broad adoption both inside Baidu's own business and across the AI industry. To date it has accumulated more than 4.06 million developers, serves 157,000 enterprises, and has been used to create more than 476,000 models. PaddlePaddle has consistently helped developers turn AI ideas into working systems and bring AI services online quickly, enabling more and more industries to adopt AI and move toward intelligent, upgraded production.
Architecture design
PaddlePaddle has an excellent architecture design. Its core AI compiler provides a well-defined IR (Intermediate Representation) system and an IR Pass system for graph optimization. As a mature AI framework, PaddlePaddle is also highly extensible: developers can support new hardware types by adding new Device types, Operators, Passes, Executors, and so on.
When building IPU support, the Graphcore R&D team used PaddlePaddle's IR layer as the entry point. The guiding principle was to minimize invasive changes to PaddlePaddle's native code: support was added incrementally by extending IR Passes and Operators, keeping the impact on PaddlePaddle's existing code paths as small as possible.

Current progress
Graphcore's IPU now supports both large-scale model training through PaddlePaddle and high-performance inference through the Paddle Inference library.
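For orientation, below is a minimal sketch of how an IPU is selected from Python with the static-graph API. It assumes the paddle.static.IpuStrategy and paddle.static.IpuCompiledProgram interfaces of the IPU-enabled Paddle 2.3 build described later in this article; the tiny network is purely illustrative, not part of the BERT example.

# Minimal sketch (assumed API of the IPU-enabled Paddle 2.3 build): run a
# small static-graph program on one IPU.
import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()
device = paddle.set_device('ipu')  # select the IPU backend

# A toy network with static shapes (the IPU requires fully static shapes).
x = static.data(name='x', shape=[4, 16], dtype='float32')
out = static.nn.fc(x, size=8)

exe = static.Executor(device)
exe.run(static.default_startup_program())  # initialize the fc weights

# Describe how the graph maps onto IPUs, then compile the program for the device.
ipu_strategy = static.IpuStrategy()
ipu_strategy.set_graph_config(num_ipus=1, is_training=False)
ipu_program = static.IpuCompiledProgram(
    static.default_main_program(), ipu_strategy=ipu_strategy
).compile(feed_list=[x.name], fetch_list=[out.name])

result = exe.run(ipu_program,
                 feed={'x': np.ones([4, 16], dtype='float32')},
                 fetch_list=[out])
print(result[0].shape)  # expected: (4, 8)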
BERT training
Pretrain Phase 1 (sequence_length=128): [figure]
| Metric | Paddle | SOTA |
| Phase 1 throughput | 9200 | |
| Phase 2 throughput | 2700 | |
| SQuAD EM | 80.48249 | 80.8 |
| SQuAD F1 | 87.556685 | 88.5 |
Getting started
Installation
1. Environment setup
Graphcore's Poplar SDK has specific hardware, operating system, and software requirements; see the Poplar SDK requirements documentation for details.
2. Build and install from source
# Download the source code
git clone -b paddle_bert_release https://github.com/graphcore/Paddle.git
# Build the Docker image
docker build -t paddlepaddle/paddle:dev-ipu-2.3.0 \
-f tools/dockerfile/Dockerfile.ipu .
# Create and run the Docker container
The IPU relies on an ipu.conf configuration file for partitioning; a valid ipu.conf is required before IPU devices can be acquired. If you do not have an ipu.conf, it can be generated with a command such as the following.
Example: generate a configuration file for a POD16 (16 IPUs):
vipu create partition ipu --size 16
ipu.conf will be generated under the following path:
ls ~/.ipuof.conf.d/
In the docker run command below, replace ${HOST_IPUOF_PATH} with the absolute path of ipu.conf on the host.
docker run --ulimit memlock=-1:-1 --net=host --cap-add=IPC_LOCK \
--device=/dev/infiniband/ --ipc=host --name paddle-ipu-dev \
-v ${HOST_IPUOF_PATH}:/ipuof \
-e IPUOF_CONFIG_PATH=/ipuof/ipu.conf \
-it paddlepaddle/paddle:dev-ipu-2.3.0 bash
Note: all subsequent steps are executed inside the container.
# Verify the IPU devices
The following command lists all IPU devices on the system, together with the IDs of the IPU devices currently in use:
gc-monitor
If gc-monitor lists the IPU devices, they have been acquired correctly. If no IPU devices can be acquired, check that the correct ipu.conf has been provided.

# Build PaddlePaddle
git clone -b paddle_bert_release https://github.com/graphcore/Paddle.git
cd Paddle
cmake -DPYTHON_EXECUTABLE=/usr/bin/python \
-DWITH_PYTHON=ON -DWITH_IPU=ON -DPOPLAR_DIR=/opt/poplar \
-DPOPART_DIR=/opt/popart -G "Unix Makefiles" -H`pwd` -B`pwd`/build
cmake --build `pwd`/build --config Release --target paddle_python -j$(nproc)
# Install the wheel package
pip install -U build/python/dist/paddlepaddle-0.0.0-cp37-cp37m-linux_x86_64.whl
# Verify the installation
python -c "import paddle; print(paddle.fluid.is_compiled_with_ipu())"
The expected output is:
> True
Examples
BERT-Base training
Source: https://github.com/graphcore/portfolio-examples/tree/master/paddlepaddle/bert-base
BERT-Base training consists of the following tasks:
- phase1: pre-training with sequence_length=128
- phase2: pre-training with sequence_length=384
- SQuAD fine-tune
- SQuAD validation
1. Data preparation:
1. Pre-train dataset (generate the data with the scripts provided by NVIDIA):
git clone https://github.com/NVIDIA/DeepLearningExamples.git
cd DeepLearningExamples/TensorFlow/LanguageModeling/BERT
bash scripts/docker/build.sh
cd data/
vim create_datasets_from_start.sh
Change --max_seq_length 512 on line 40 to --max_seq_length 384, and change --max_predictions_per_seq 80 on line 41 to --max_predictions_per_seq 56.
cd ../
bash scripts/data_download.sh wiki_only
This will generate tfrecord input data with sequence_length=128 and 384.
2. SQuAD dataset
# Fine-tune dataset
curl --create-dirs -L https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -o data/squad/train-v1.1.json
# Validation dataset
curl --create-dirs -L https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -o data/squad/dev-v1.1.json
2. PaddleNLP:
In addition to PaddlePaddle (installed above), this example depends on PaddleNLP for model construction and data processing. Install PaddleNLP as follows:
# Install dependencies:
pip3.7 install jieba h5py colorlog colorama seqeval multiprocess numpy==1.19.2 paddlefsl==1.0.0 six==1.13.0 wandb
pip3.7 install torch==1.7.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip3.7 install torch-xla@https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp37-cp37m-linux_x86_64.whl
# Install PaddleNLP
pip3.7 install git+https://github.com/graphcore/PaddleNLP.git@paddle_bert_release
3. Run:
Set --input_dir in run_stage.sh to the path of the corresponding input data:
- Phase1: directory containing the tfrecord (sequence_length=128) files
- Phase2: directory containing the tfrecord (sequence_length=384) files
- Fine-tune: directory containing train-v1.1.json
- Validation: directory containing dev-v1.1.json
run_stage.sh takes four arguments:
- device: ipu or cpu
- stage: phase1, phase2, SQuAD, or validation
- input_pdparams: path and prefix of the pdparams to load
- output_pdparams: path and prefix of the pdparams to export
Note: the program uses wandb to log run data. A prompt will appear at startup asking you to choose a wandb mode; if you do not have a wandb account, enter 3.
# Run Phase1:
# phase1 does not load any params; weights are randomly initialized
./run_stage.sh ipu phase1 _ pretrained_128_model
# Run Phase2:
# phase2 loads the params trained in phase1
./run_stage.sh ipu phase2 pretrained_128_model pretrained_384_model
# Run SQuAD fine-tune:
# fine-tune loads the params trained in phase2
./run_stage.sh ipu SQuAD pretrained_384_model finetune_model
# Run validation:
./run_stage.sh ipu validation finetune_model _
Paddle Inference demo
# Build the Paddle Inference library
Add -DON_INFER=ON to the PaddlePaddle cmake command shown earlier. When the build finishes, a paddle_inference_install_dir directory is generated under the build directory; this directory is the Paddle Inference library.
cmake -DPYTHON_EXECUTABLE=/usr/bin/python \
-DWITH_PYTHON=ON -DON_INFER=ON -DWITH_IPU=ON -DPOPLAR_DIR=/opt/poplar \
-DPOPART_DIR=/opt/popart -G "Unix Makefiles" -H`pwd` -B`pwd`/build
cmake --build `pwd`/build --config Release --target paddle_python -j$(nproc)

# The following ipu_word2vec_sample.cc shows how to run inference with the Paddle Inference library:
Download the model:
wget -q http://paddle-inference-dist.bj.bcebos.com/word2vec.inference.model.tar.gz
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
 * This file contains a simple demo for how to take a model for inference with IPUs.
 */
#include <iostream>
#include <vector>
#include <numeric>
#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "gflags/gflags.h"
#include "glog/logging.h"

DEFINE_string(infer_model, "", "Directory of the inference model.");

using paddle_infer::Config;
using paddle_infer::Predictor;
using paddle_infer::CreatePredictor;

void inference(std::string model_path, bool use_ipu,
               std::vector<float> *out_data) {
  //# 1. Create Predictor with a config.
  Config config;
  config.SetModel(FLAGS_infer_model);
  if (use_ipu) {
    // ipu_device_num, ipu_micro_batch_size
    config.EnableIpu(1, 4);
  }
  auto predictor = CreatePredictor(config);

  //# 2. Prepare input/output tensor.
  auto input_names = predictor->GetInputNames();
  std::vector<int64_t> data{1, 2, 3, 4};
  // For simplicity, we set all the slots with the same data.
  for (auto input_name : input_names) {
    auto input_tensor = predictor->GetInputHandle(input_name);
    input_tensor->Reshape({4, 1});
    input_tensor->CopyFromCpu(data.data());
  }

  //# 3. Run
  predictor->Run();

  //# 4. Get output.
  auto output_names = predictor->GetOutputNames();
  auto output_tensor = predictor->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = output_tensor->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  out_data->resize(out_num);
  output_tensor->CopyToCpu(out_data->data());
}

int main(int argc, char *argv[]) {
  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
  std::vector<float> ipu_result;
  std::vector<float> cpu_result;
  inference(FLAGS_infer_model, true, &ipu_result);
  inference(FLAGS_infer_model, false, &cpu_result);
  for (size_t i = 0; i < ipu_result.size(); i++) {
    CHECK_NEAR(ipu_result[i], cpu_result[i], 1e-6);
  }
  LOG(INFO) << "Finished";
}
# How to build:
CMakeLists.txt:
cmake_minimum_required(VERSION 3.0)
project(cpp_inference_demo CXX C)
include_directories("${PADDLE_LIB}/")
set(PADDLE_LIB_THIRD_PARTY_PATH "${PADDLE_LIB}/third_party/install/")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/include")
include_directories("${PADDLE_LIB_THIRD_PARTY_PATH}mklml/include")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}protobuf/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}glog/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}gflags/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}xxhash/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}cryptopp/lib")
link_directories("${PADDLE_LIB_THIRD_PARTY_PATH}ipu")
link_directories("/opt/poplar/lib")
link_directories("/opt/popart/lib")
link_directories("${PADDLE_LIB}/paddle/lib")
set(EXTERNAL_LIB "-lrt -ldl -lpthread")
set(DEPS ${DEPS}
    ${PADDLE_LIB_THIRD_PARTY_PATH}mkldnn/lib/libdnnl.so.2
    ${PADDLE_LIB_THIRD_PARTY_PATH}mklml/lib/libiomp5.so
    paddle_inference paddle_ipu flags
    glog gflags protobuf xxhash cryptopp
    ${EXTERNAL_LIB})
set(CMAKE_CXX_FLAGS "-std=c++11")
add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
target_link_libraries(${DEMO_NAME} ${DEPS})
# Build script compile.sh:
Note: replace ${PADDLE_INFERENCE_INSTALL_DIR} with the path of your paddle_inference library directory.
#!/bin/bash
mkdir -p build
cd build
rm -rf *
DEMO_NAME=ipu_word2vec_sample
LIB_DIR=${PADDLE_INFERENCE_INSTALL_DIR}
cmake .. -DPADDLE_LIB=${LIB_DIR} -DDEMO_NAME=${DEMO_NAME}
make -j
# Build:
./compile.sh
#Run:
This demo runs inference on both the IPU and the CPU and compares the two sets of results.
./ipu_word2vec_sample --infer_model=word2vec.inference.model
Related resources