Skip to content

rwkv7 训练踩坑

2026-04-06 · 982字 · 4分钟 · 浏览量

最近放假,正好尝试一下训练一个小语言模型。23 年的时候就耳闻 RWKV4,但囿于其生态、源码混乱以及自己码力不足等等问题一直未尝试,今天来尝试从头预训练一个 RWKV7。

RWKV7 在 github 上有两个仓库提供预训练代码:一个是 Bo 的原仓库 BlinkDL/RWKV-LM,另一个是社区仓库 RWKV-Vibe/RWKV-LM-V7,起先我拉取的是后者,因为看似后者的代码更新一些。

结果不出所料跑不起来,我只能换成 Bo 的代码,以便直接在 discord 上请教之。仍然跑不起来。下面是 debug 的记录:

  1. 直接采用原代码默认的 deepspeed_stage_2 策略,训练框架在通过 deepspeed 初始化 TorchBackend 时就出现 SegFault,考虑在训练代码前添加
py
import faulthandler
faulthandler.enable()

打印 SegFault 前的最后的调用栈。发现是 torch.broadcast 的问题。于是考虑运行下面的测试脚本:

py
import torch
import torch.distributed as dist
import os

os.environ['NCCL_IB_DISABLE'] = '1'
    # 模拟 Lightning 的 DDP/DeepSpeed 初始化环境
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

def test_broadcast():
    # 这里的 backend 换成 "nccl" (GPU) 或 "gloo" (CPU)
    # 如果 nccl 崩了,试试 gloo
    dist.init_process_group(backend="nccl", rank=0, world_size=1)
    
    # 模拟 Lightning 的 log_dir 字符串广播
    test_path = "/home/ids/smao-22/phd/chongmo/out/test_dir"
    object_list = [test_path]
    
    print(f"Before broadcast: {object_list}")
    try:
        # Lightning 内部实际调用的就是这个
        dist.broadcast_object_list(object_list, src=0)
        print(f"After broadcast: {object_list}")
    except Exception as e:
        print(f"Broadcast failed with error: {e}")
    
    dist.destroy_process_group()

def test_tensor_broadcast():
    dist.init_process_group(backend="nccl", rank=0, world_size=1)
    
    # 创建一个简单的 Tensor,不涉及 Python 对象序列化
    x = torch.ones(10).cuda()
    print("Tensor before broadcast:", x)
    
    try:
        # 直接广播原始内存数据
        dist.broadcast(x, src=0)
        print("Tensor broadcast success!")
    except Exception as e:
        print(f"Tensor broadcast failed: {e}")
    
    dist.destroy_process_group()

if __name__ == "__main__":
    test_tensor_broadcast()
    test_broadcast()

报错,确认是 broadcast 问题。尝试禁用os.environ['NCCL_IB_DISABLE'] = '1'以后,可以运行。IB 通信对多机分布式训练有影响,但对单机分布式依靠 NVLink,所以关闭以后问题不大。

  1. 在 debug 上述问题时,我首先怀疑是分布式训练问题,于是把ligthning的策略从deepspeed_stage_2改成了auto,结果报AssertionError,发现原来 RWKV7 的前后向过程都默认采用 bf16 精度,而在其自定义的 cuda 算子中还开启了精度检查。最终确定是lightningauto策略会在precision='bf16'时自动开启混合精度训练,即自动在前向过程中开启torch.autocast。为此我在其training_step添加了关闭混合精度训练的装饰器:
py
@torch.autocast('cuda',enabled=False)
def training_step(self, batch, batch_idx):
    ...

但在开启 deepspeed 策略以后可以注释掉装饰器,因为 deepspeed 将接管精度控制。

此后便可正常训练了,以下是一些数据:

  • [x] 单机单卡训练,单机单卡下原训练脚本有个 bug,即把 strategy 改成 auto 以后会自动在 forward pass autocast 到 float32,导致RWKV7_CLAMPW_CUDA_OP内部 assertion error,需在 training_step 外加装饰器`
    • 设备:A40 46GiB:
      • bs=8,ctx_len=4096,GRAD_CP=0,显存占用约 32GiB,1h/epoch,需要 294epochs 才能跑完一遍数据集。
      • bs=64,ctx_len=4096,GRAD_CP=1,显存占用约 32GiB,1h/epoch,需要 294epochs 才能跑完一遍数据集。
  • [x] 单机单卡 deepspeed 训练:目前有 torch broadcast 的 bug 待修复。服务器的 nccl ib 通信有问题,设置os.environ['NCCL_IB_DISABLE'] = '1'将其关闭即可进行deepspeed_stage_2训练
    • 设备:A100 80GiB:
      • bs=64,ctx_len=4096,GRAD_CP=1,显存占用约 34GiB,30min/epoch,需要 294epochs 才能跑完一遍数据集。
      • bs=128,ctx_len=4096,GRAD_CP=1,显存占用约 66GiB,30min/epoch,需要 294epochs 才能跑完一遍数据集。
  • [x] 单机多卡 deepspeed 训练:目前有 torch broadcast 的 bug 待修复。服务器的 nccl ib 通信有问题,设置os.environ['NCCL_IB_DISABLE'] = '1'将其关闭即可进行deepspeed_stage_2训练
    • 设备:8*A100 40GiB:
      • bs=256(32*8),ctx_len=4096,GRAD_CP=1,显存占用约 18GiB,5min/epoch,需要 294epochs 才能跑完一遍数据集,一天足矣。

注意到开启梯度检查器后,显存占用大幅下降,而训练速度几乎不变。

返回

人同此心,心同此理;如风沐面,若水润心