本文将通过 PrismaticVLM 多模态架构为例,全面解析如何使用 PyTorch 的 FSDP(Fully Sharded Data Parallel)对 Vision、Projector 与 LLM 模块进行包裹,同时配合冻结策略实现高效稳定的多阶段训练。


一、FSDP 包裹策略的三类模块

在大模型训练中,合理使用 FSDP 自动包裹策略(auto wrap policy) 是显著降低显存占用与优化分布式性能的关键。PrismaticVLM 模型中涉及三类模块:

1. LLM 模块:LlamaDecoderLayer

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
llm_fsdp_wrapping_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer}
)

该策略递归扫描模型中所有子模块,并在每个 LlamaDecoderLayer 处自动打包。


2. Vision 模块:VisionTransformer

此策略直接包裹整个 VisionTransformer 模块本体。

若要精细到每个 Block,也可以使用:

vit_wrap_policy = partial(
    _module_wrap_policy,  # a custom wrap policy helper
    module_classes={VisionTransformer}
)
transformer_block_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block}
)

vision_fsdp_wrapping_policy =  partial(
            _or_policy,
            policies=[ 
                vit_wrap_policy, 
                transformer_block_policy
                ],
            )

3. Projector 模块:Linear / MLP

一般作为桥接图像与文本向量的 MLP,也可以通过 module_classes 来单独指定:

projector_fsdp_wrapping_policy = partial(
            _module_wrap_policy,
            module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
)

✅ 三者如何协同?

这些策略是互不冲突的,可以通过 _or_policy 组合使用:

from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
 
policies =  partial(
            _or_policy,
            policies=[
                vision_fsdp_wrapping_policy,
                llm_fsdp_wrapping_policy,
                projector_fsdp_wrapping_policy,
                ],
            )

最终结果是:每个核心模块都在适当粒度上被独立包装,避免冗余且利于 shard 化。


二、VLM 模型结构图(FSDP 包裹标记)

PrismaticVLM
├── vision_backbone (VisionTransformer)
│   ├── Block #1     ← wrap
│   ├── Block #2     ← wrap
│   └── ...
├── projector (Linear / MLP) ← wrap
└── llm_backbone (LLaMA)
    ├── LlamaDecoderLayer #1 ← wrap
    ├── LlamaDecoderLayer #2 ← wrap
    └── ... 

每一层的 wrap 都独立处理,在多卡训练时可最大化 shard 和通信效率。

对于vision的两个包裹策略,两者不会冲突,而是协同工作。最终结果是:

  • 外层是 VisionTransformer 的 FSDP 包装。

  • 内层每个 Block 也被 FSDP 包装,实现更细粒度的分片。


三、vision_backbone 使用 float32 的必要性

上一篇已经介绍了VLM分阶段训练的策略:

阶段 Vision Encoder (vision_backbone) LLM (llm_backbone) Projector
align 冻结 ❄️ 冻结 ❄️ 训练 🔥
finetune / vla-train 冻结 ❄️ 训练 🔥 训练 🔥
full-finetune / vla-full-train 训练 🔥(强制 float32) 训练 🔥 训练 🔥
last-layer-finetune 冻结 ❄️ 仅最后层训练 🔥 冻结 ❄️
vla-sandwich-train 训练 🔥(强制 float32) 仅最后层训练 🔥 训练 🔥

但是在finetune的阶段,即使一个模块被冻结,它依然会经历前向传播。而在 AMP(混合精度)模式下,forward 时默认会把它的计算转换为 float16,这可能会导致部分组件如 FrozenBatchNorm2d 或 residual-add 操作数值不稳定:

❗可能出现的问题:

  • LayerNorm 精度不够,输出异常波动

  • Residual Add 下溢,结果不收敛

  • FrozenBatchNorm2d 精度丢失

  • 梯度虽然不会更新,但 forward 会 silent 崩溃

✅ 正确做法:

with torch.no_grad():
    vision_backbone.to(dtype=torch.float32)
autocast_dtype = self.llm_backbone.half_precision_dtype
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
     generated_ids = super().generate(
                input_ids=input_ids,            # Shape: [1, seq]
                pixel_values=pixel_values,      # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
                **kwargs
     )

 这样可以保证即使 frozen,forward 使用 float32 保留稳定性


四、结论与实践建议

模块 包裹建议 精度建议
Vision 每个 Block 包裹 冻结时使用 float32
Projector 整体包裹 可使用 AMP
LLM 每个 Layer 包裹 可使用 AMP,部分层可训练
  • FSDP + AMP 是多模态训练中的标配组合

  • ✅ 模块粒度包裹有利于显存优化与调试追踪

  • ✅ 模块冻结也需要 dtype 管理,非“无脑冻结”

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐