精度对齐#

1. 总览#

1.1 背景#

模型精度对齐是开展后续工作的关键,确保了相同模型在相同环境和相同参数配置下输出结果的稳定性和一致性,为后续的数据分析、决策制定以及系统优化提供了坚实的基础。

1.2 前序工作#

基于精度对齐验收标准,建议准备以下内容:

  • 准备好训练/验证数据集,用于模型训练与评估。

  • 准备好PyTorch模型结构,作为模型精度baseline。

  • 准备好验证设备,如模型参数为fp16,可准备V100、A100等计算卡,如模型参数为bf16,需准备A100等计算卡。

2. 整体流程#

整体流程包含模型结构对齐、准备小数据集、前向初次对齐、损失函数对齐、优化器对齐、学习率对齐、正则化策略对齐、反向初次对齐、训练集数据对齐和训练对齐。针对采用并行策略的大模型而言,分别增加了并行模型结构对齐、并行前向初次对齐和并行反向初次对齐。

2.1 流程概览#

验证模型精度的整体流程如下图所示:

align_workflow

3. 模型对齐流程#

3.1 模型结构对齐#

对齐模型结构时,一般有3个主要步骤:

  • 网络结构代码转换

  • 权重转换

  • 模型组网正确性验证

3.1.1 网络结构代码转换#

【基本流程】

PyTorch的API和PaddlePaddle的API基本相似, 可以参考 PyTorch最新release与Paddle develop API映射表 , 部分组网代码也可手动转换。

【代码自动转换工具】

代码自动转换工具PaConvert 能自动将其它深度学习框架训练或推理的代码,转换为 PaddlePaddle 的代码,方便快速自动地 模型代码迁移。

目前仅支持自动转换 Pytorch 代码,其它深度学习框架的支持后续新增中, 转换时会尽量保持原代码的风格与结构,将其它深度学习框架的 API 接口 转换为 PaddlePaddle 的 API 接口。

【大模型网络结构示例】

3.1.2 权重转换#

【基本流程】

组网代码转换完成之后,需要对模型权重进行转换。

  1import json
  2import os
  3import shutil
  4import copy
  5import paddle
  6import torch
  7from safetensors.torch import load_file
  8from safetensors.numpy import save_file
  9from paddlenlp.utils.log import logger
 10from paddlenlp.transformers import Qwen2MoeForCausalLM, AutoConfig
 11
 12
 13def execute_cmd(cmd, file_path):
 14    cmd = cmd + " " + file_path
 15    os.system(cmd)
 16
 17
 18def convert_from_torch_to_paddle(torch_path=None, paddle_path=None, torch_prefix_key="model.", paddle_class=Qwen2MoeForCausalLM, delete_after_convert=False):
 19    assert torch_path is not None
 20    if paddle_path is None:
 21        paddle_path = torch_path + "-paddle"
 22    if not os.path.exists(paddle_path):
 23        os.mkdir(paddle_path)
 24
 25    config = AutoConfig.from_pretrained(torch_path)
 26    name_mappings = paddle_class._get_name_mappings(config=config)
 27
 28    torch_prefix_key = torch_prefix_key
 29    paddle_prefix_key = paddle_class.base_model_prefix + "."
 30
 31    if os.path.exists(os.path.join(torch_path, "model.safetensors.index.json")):
 32        index = json.load(open(os.path.join(torch_path, "model.safetensors.index.json")))
 33        dst_index = copy.deepcopy(index)
 34
 35        for key in list(dst_index["weight_map"].keys()):
 36            paddle_key = key.replace(torch_prefix_key, paddle_prefix_key)
 37            dst_index["weight_map"][paddle_key] = dst_index["weight_map"].pop(key)
 38
 39        files = set(index["weight_map"].values())
 40        logger.info(files)
 41
 42        for file_name in sorted(os.listdir(torch_path)):
 43            # skip hidden files
 44            if file_name.startswith("."):
 45                continue
 46
 47            logger.info(file_name)
 48            if file_name in files:
 49                # convert safetensors to safetensors(paddle)
 50                convert_safetensors_from_torch_to_paddle(file_name,
 51                                                        torch_path,
 52                                                        paddle_path,
 53                                                        torch_prefix_key,
 54                                                        paddle_prefix_key,
 55                                                        name_mappings,
 56                                                        delete_after_convert=False)
 57            else:
 58                # copy config.json and other files
 59                shutil.copy(os.path.join(torch_path, file_name), os.path.join(paddle_path, file_name))
 60
 61        json.dump(dst_index, open(os.path.join(paddle_path, "model.safetensors.index.json"), "w"), indent=2)
 62    else:
 63        for file_name in sorted(os.listdir(torch_path)):
 64            # skip hidden files
 65            if file_name.startswith("."):
 66                continue
 67
 68            logger.info(file_name)
 69            if file_name == "model.safetensors":
 70                convert_safetensors_from_torch_to_paddle(file_name,
 71                                                        torch_path,
 72                                                        paddle_path,
 73                                                        torch_prefix_key,
 74                                                        paddle_prefix_key,
 75                                                        name_mappings,
 76                                                        delete_after_convert=False)
 77            else:
 78                # copy config.json and other files
 79                shutil.copy(os.path.join(torch_path, file_name), os.path.join(paddle_path, file_name))
 80
 81    execute_cmd(cmd="sed -i -e  's/torch_dtype/dtype/g' ",
 82                file_path=os.path.join(paddle_path, "config.json"))
 83
 84def convert_safetensors_from_torch_to_paddle(file_name, torch_path, paddle_path, torch_prefix_key, paddle_prefix_key, name_mappings, delete_after_convert=False):
 85    tensors = load_file(os.path.join(torch_path, file_name))
 86
 87    transpose_state_dict = {}
 88    for name_mapping in name_mappings:
 89        if name_mapping.action == "transpose":
 90            transpose_state_dict[name_mapping.target_name] = True
 91        else:
 92            transpose_state_dict[name_mapping.target_name] = False
 93
 94    for key in list(tensors.keys()):
 95        paddle_key = key.replace(torch_prefix_key, paddle_prefix_key)
 96        logger.info("{} {}".format(key, tensors[key].shape))
 97        if transpose_state_dict[paddle_key]:
 98            t = tensors.pop(key).cuda().t().contiguous()
 99            capsule = torch.utils.dlpack.to_dlpack(t)
100            t = paddle.utils.dlpack.from_dlpack(capsule)
101            tensors[paddle_key] = t.numpy()
102        else:
103            t = tensors.pop(key).cuda()
104            capsule = torch.utils.dlpack.to_dlpack(t)
105            t = paddle.utils.dlpack.from_dlpack(capsule)
106            tensors[paddle_key] = t.numpy()
107
108            # tensors[dst_key] = paddle.to_tensor(tensors.pop(key).cuda().float().cpu().numpy(), dtype="bfloat16").numpy()
109        logger.info("{} {}".format(paddle_key, tensors[paddle_key].shape))
110
111    save_file(tensors, os.path.join(paddle_path, file_name), metadata={"format": "np"})
112    if delete_after_convert:
113        os.remove(os.path.join(torch_path, file_name))
114
115
116convert_from_paddle_to_torch(paddle_path="/root/code/PaddleNLP/ckpt/Qwen/Qwen2-0.5B" paddle_class=Qwen2MoeForCausalLM)

其中,模型结构中需实现_get_name_mapping方法,在这个方法中会将线性层参数标识需要转置的参数,进而适配Paddle nn.Linear的参数。参考如Qwen模型结构:

PaddlePaddle/PaddleNLP

 1class Qwen2PretrainedModel(PretrainedModel):
 2    @classmethod
 3    def _get_name_mappings(cls, config: Qwen2Config) -> list[StateDictNameMapping]:
 4        mappings: list[StateDictNameMapping] = []
 5        model_mappings = [
 6            ["embed_tokens.weight"],
 7            ["norm.weight"],
 8        ]
 9        for layer_index in range(config.num_hidden_layers):
10            layer_mappings = [
11                [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
12                [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"],
13                [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"],
14                [f"layers.{layer_index}.self_attn.q_proj.bias", None],
15                [f"layers.{layer_index}.self_attn.k_proj.bias", None],
16                [f"layers.{layer_index}.self_attn.v_proj.bias", None],
17                [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"],
18                [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"],
19                [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"],
20                [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"],
21                [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"],
22                [f"layers.{layer_index}.input_layernorm.weight"],
23                [f"layers.{layer_index}.post_attention_layernorm.weight"],
24            ]
25            model_mappings.extend(layer_mappings)
26
27        init_name_mappings(mappings=model_mappings)
28        # base-model prefix "Qwen2MoEModel"
29        if "Qwen2Model" not in config.architectures:
30            for mapping in model_mappings:
31                mapping[0] = "model." + mapping[0]
32                mapping[1] = "qwen2." + mapping[1]
33            if not config.tie_word_embeddings:
34                model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"])
35
36        mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
37        return mappings

3.1.3 模型组网正确性验证#

【基本流程】

  1. 定义PyTorch模型,加载权重,固定seed,基于numpy生成随机数,转换为PyTorch可以处理的tensor,送入网络,获取输出。

  2. 定义PaddlePaddle模型,加载权重,固定seed,基于numpy生成随机数,转换为PaddlePaddle可以处理的tensor,送入网络,获取输出。

  3. 排查diff,小于阈值,即可完成自测。

【示例代码】

 1import numpy as np
 2import paddle
 3import torch
 4from transformers import Qwen2Config as Qwen2Config_hf
 5from transformers import Qwen2ForCausalLM as Qwen2ForCausalLM_hf
 6
 7from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM
 8
 9def eval_model_convert():
10    paddle_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
11    torch_input_ids = torch.LongTensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
12
13    # paddle model
14    paddle_ckpt_path = "Qwen/Qwen2-0.5B"
15    config_paddle = Qwen2Config.from_pretrained(paddle_ckpt_path)
16    model_paddle = Qwen2ForCausalLM.from_pretrained(paddle_ckpt_path, config=config_paddle, dtype="float32")
17
18    # torch model
19    torch_ckpt_path = "/root/.cache/modelscope/hub/Qwen/Qwen2-0___5B"
20    config_torch = Qwen2Config_hf.from_pretrained(torch_ckpt_path, trust_remote_code=True)
21    config_torch.dtype = "float32"
22    model_torch = Qwen2ForCausalLM_hf.from_pretrained(torch_ckpt_path, config=config_torch, trust_remote_code=True)
23
24    model_paddle.eval()
25    model_torch.eval()
26
27    out_paddle = model_paddle(paddle_input_ids)[0]
28    out_torch = model_torch(torch_input_ids, return_dict=False)[0]
29
30    assert np.allclose(out_paddle.numpy(), out_torch.detach().numpy(), rtol=1e-5, atol=1e-3)
31
32eval_model_convert()

【注意事项】

  • 模型在前向对齐验证时,需要调用model.eval()方法,保证组网中的随机量被关闭,比如BatchNorm、Dropout等。

  • 给定相同的输入数据,为保证可复现性,如果有随机数生成,固定相关的随机种子。

  • 输出diff可以使用np.max(np.abs(o1 - o2))进行计算,一般小于1e-5的话,可以认为前向没有问题。如果最终输出结果diff较大,可以使用二分的方法进行排查,比如说BERT,包含1个embdding层、12个transformer-block以及最后的MLM head层,那么完成模型组网和权重转换之后,如果模型输出没有对齐,可以尝试输出中间某一个transformer-block的tensor进行对比,如果相同,则向后进行排查;如果不同,则继续向前进行排查,以此类推,直到找到导致没有对齐的操作。

  • 在验证精度时需设置环境变量,避免算子的随机性,环境变量如下:

 1# 通用环境变量,避免随机性
 2export NVIDIA_TF32_OVERRIDE=0
 3export FLAGS_embedding_deterministic=1
 4export FLAGS_cudnn_deterministic=1
 5
 6# 并行计算环境变量,避免随机性
 7export Flags_mp_aysnc_allreduce=1
 8export Flags_skip_mp_c_identity=1
 9export FLAGS_shard_norm_align_dp=0
10export FLAGS_shard_use_reduce=1
11export FLAGS_sync_before_allreduce=1

3.1.4 分布式组网对齐#

【基本流程】

基本流程同 3.1.3 模型组网正确性验证。此外,在模型初始化时,需创建分布式并行环境,并使用paddle.distributed.launch进行启动运行,示例命令如下:

1python -m paddle.distributed.launch --devices 0,1 compare_torch_with_paddle.py

【示例代码】

 1import numpy as np
 2import paddle
 3import torch
 4from padiff import auto_diff
 5from transformers import Qwen2Config as Qwen2Config_hf
 6from transformers import Qwen2ForCausalLM as Qwen2ForCausalLM_hf
 7from paddle.distributed import fleet
 8from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM
 9
10def eval_model_convert_parallel(mp_degree=1):
11    paddle_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
12    torch_input_ids = torch.LongTensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
13
14    strategy = fleet.DistributedStrategy()
15    strategy.hybrid_configs = {
16        "dp_degree": 1,
17        "mp_degree": mp_degree,
18        "pp_degree": 1,
19        "sharding_degree": 1,
20    }
21    fleet.init(is_collective=True, strategy=strategy)
22    hcg = fleet.get_hybrid_communicate_group()
23
24    # paddle model
25    paddle_ckpt_path = "Qwen/Qwen2-0.5B"
26    config_paddle = Qwen2Config.from_pretrained(paddle_ckpt_path)
27    config_paddle.tensor_parallel_degree = hcg.get_model_parallel_world_size()
28    config_paddle.tensor_parallel_rank = hcg.get_model_parallel_rank()
29    config_paddle.tensor_parallel_output = False
30    model_paddle = Qwen2ForCausalLM.from_pretrained(paddle_ckpt_path, config=config_paddle, dtype="float32")
31
32    # torch model
33    torch_ckpt_path = "/root/.cache/modelscope/hub/Qwen/Qwen2-0___5B"
34    config_torch = Qwen2Config_hf.from_pretrained(torch_ckpt_path, trust_remote_code=True)
35    config_torch.dtype = "float32"
36    model_torch = Qwen2ForCausalLM_hf.from_pretrained(torch_ckpt_path, config=config_torch, trust_remote_code=True)
37
38    model_paddle.eval()
39    model_torch.eval()
40
41    # 手动验证
42    out_paddle = model_paddle(paddle_input_ids)[0]
43    out_torch = model_torch(torch_input_ids, return_dict=False)[0]
44    assert np.allclose(out_paddle.numpy(), out_torch.detach().numpy(), rtol=1e-5, atol=1e-4)
45
46eval_model_convert_parallel(mp_degree=2)

【注意事项】

  • 在验证精度时需设置环境变量,避免算子的随机性,环境变量如下:

 1# 通用环境变量,避免随机性
 2export NVIDIA_TF32_OVERRIDE=0
 3export FLAGS_embedding_deterministic=1
 4export FLAGS_cudnn_deterministic=1
 5
 6# 并行计算环境变量,避免随机性
 7export Flags_mp_aysnc_allreduce=1
 8export Flags_skip_mp_c_identity=1
 9export FLAGS_shard_norm_align_dp=0
10export FLAGS_shard_use_reduce=1
11export FLAGS_sync_before_allreduce=1

3.2 前向对齐&反向对齐-对齐工具验证#

【基本流程】

上述手动验证方式对开发者而言较为繁琐,可采用自动验证PaDiff进行验证。PaDiff 是基于 PaddlePaddle 与 PyTorch 的模型精度对齐工具。传入 Paddle 或 Torch 模型,对齐训练中间结果以及训练后的模型权重,并提示精度 diff 第一次出现的位置。

PaDiff: PaddlePaddle/PaDiff

【使用方式】

 1import numpy as np
 2import paddle
 3import torch
 4from padiff import auto_diff
 5from transformers import Qwen2Config as Qwen2Config_hf
 6from transformers import Qwen2ForCausalLM as Qwen2ForCausalLM_hf
 7
 8from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM
 9
10
11def eval_model_convert():
12    paddle_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
13    torch_input_ids = torch.LongTensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]])
14
15    # paddle model
16    paddle_ckpt_path = "Qwen/Qwen2-0.5B"
17    config_paddle = Qwen2Config.from_pretrained(paddle_ckpt_path)
18    model_paddle = Qwen2ForCausalLM.from_pretrained(paddle_ckpt_path, config=config_paddle, dtype="float32")
19
20    # torch model
21    torch_ckpt_path = "/root/.cache/modelscope/hub/Qwen/Qwen2-0___5B"
22    config_torch = Qwen2Config_hf.from_pretrained(torch_ckpt_path, trust_remote_code=True)
23    config_torch.dtype = "float32"
24    model_torch = Qwen2ForCausalLM_hf.from_pretrained(torch_ckpt_path, config=config_torch, trust_remote_code=True)
25
26    model_paddle.eval()
27    model_torch.eval()
28
29    # 手动验证
30    out_paddle = model_paddle(paddle_input_ids)[0]
31    out_torch = model_torch(torch_input_ids, return_dict=False)[0]
32    assert np.allclose(out_paddle.numpy(), out_torch.detach().numpy(), rtol=1e-5, atol=1e-4)
33
34    # 使用padiff验证
35    inp = ({"input_ids": torch_input_ids,
36            "use_cache": False,
37            "output_attentions": False,
38            "output_hidden_states": False,
39            "return_dict": False},
40        {"input_ids": paddle_input_ids})
41    # diff_phase 可以设置为forward,backword和both
42    auto_diff(model_torch, model_paddle, inp, atol=1e-4, rtol=1e3, auto_init=False, diff_phase="both", compare_mode="strict")
43
44eval_model_convert()

精度对齐情况参考,可作为验证标准

model

size

logits diff (float32)

loss diff (float32)

each tensor in all layers (float32)

Qwen/Qwen2-0.5B

0.5B

1e-4

1e-5

1e-4

Qwen/Qwen2-1.5B

1.5B

1e-3

1e-5

1e-3

Qwen/Qwen2-7B

7B

1e-3

1e-5

1e-3

Qwen/Qwen1.5-14B

14B

1e-4

1e-5

1e-4

3.3 模型训练对齐#

【基本流程】

完成前面的步骤之后,就可以开始全量数据的训练对齐任务了。按照下面的步骤进行训练对齐。

  1. 准备train/eval data, loader, model

  2. 模型初始化

  3. 加载配置,开始训练,迭代得到最终模型与评估指标。

【注意事项】

  1. 【强烈】建议先做完反向对齐之后再进行模型训练对齐,二者之间的不确定量包括:数据集、PaddlePaddle与参考代码在模型training mode下的区别,初始化参数。

  2. 在训练对齐过程中,受到较多随机量的影响,精度有少量diff是正常的,以SST-2数据集的分类为例,diff在0.15%以内可以认为是正常的,这里可以根据不同的任务,适当调整对齐检查的阈值(ReprodDiffHelper.report函数中的diff_threshold参数)。

  3. 训练过程中的波动是正常的,如果最终收敛结果不一致,可以从以下方面进行排查:

  • 仔细排查Dropout、BatchNorm以及其他组网模块及超参是否无误。

  • 基于参考代码随机生成一份预训练模型,转化为PaddlePaddle的模型,并使用PaddlePaddle加载训练,对比二者的收敛曲线与最终结果,排查初始化影响。

  • 使用参考代码的Dataloader生成的数据,进行模型训练,排查train dataloader的影响。

参考文档:

  1. PaddlePaddle/PaDiff

  2. PaddlePaddle/models