多模态学习系列(二):FSDP 与混合训练的技巧
本文将通过 VLM 多模态架构为例,全面解析如何使用 PyTorch 的 FSDP(Fully Sharded Data Parallel)对 Vision、Projector 与 LLM 模块进行包裹,同时配合冻结策略实现高效稳定的多阶段训练。
本文将通过 PrismaticVLM 多模态架构为例,全面解析如何使用 PyTorch 的 FSDP(Fully Sharded Data Parallel)对 Vision、Projector 与 LLM 模块进行包裹,同时配合冻结策略实现高效稳定的多阶段训练。
一、FSDP 包裹策略的三类模块
在大模型训练中,合理使用 FSDP 自动包裹策略(auto wrap policy) 是显著降低显存占用与优化分布式性能的关键。PrismaticVLM 模型中涉及三类模块:
1. LLM 模块:LlamaDecoderLayer
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
llm_fsdp_wrapping_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer}
)
该策略递归扫描模型中所有子模块,并在每个 LlamaDecoderLayer
处自动打包。
2. Vision 模块:VisionTransformer
此策略直接包裹整个 VisionTransformer 模块本体。
若要精细到每个 Block,也可以使用:
vit_wrap_policy = partial(
_module_wrap_policy, # a custom wrap policy helper
module_classes={VisionTransformer}
)
transformer_block_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={Block}
)
vision_fsdp_wrapping_policy = partial(
_or_policy,
policies=[
vit_wrap_policy,
transformer_block_policy
],
)
3. Projector 模块:Linear / MLP
一般作为桥接图像与文本向量的 MLP,也可以通过 module_classes
来单独指定:
projector_fsdp_wrapping_policy = partial(
_module_wrap_policy,
module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
)
✅ 三者如何协同?
这些策略是互不冲突的,可以通过 _or_policy
组合使用:
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
policies = partial(
_or_policy,
policies=[
vision_fsdp_wrapping_policy,
llm_fsdp_wrapping_policy,
projector_fsdp_wrapping_policy,
],
)
最终结果是:每个核心模块都在适当粒度上被独立包装,避免冗余且利于 shard 化。
二、VLM 模型结构图(FSDP 包裹标记)
PrismaticVLM
├── vision_backbone (VisionTransformer)
│ ├── Block #1 ← wrap
│ ├── Block #2 ← wrap
│ └── ...
├── projector (Linear / MLP) ← wrap
└── llm_backbone (LLaMA)
├── LlamaDecoderLayer #1 ← wrap
├── LlamaDecoderLayer #2 ← wrap
└── ...
每一层的 wrap 都独立处理,在多卡训练时可最大化 shard 和通信效率。
对于vision的两个包裹策略,两者不会冲突,而是协同工作。最终结果是:
外层是
VisionTransformer
的 FSDP 包装。内层每个
Block
也被 FSDP 包装,实现更细粒度的分片。
三、vision_backbone 使用 float32 的必要性
上一篇已经介绍了VLM分阶段训练的策略:
阶段 | Vision Encoder (vision_backbone ) |
LLM (llm_backbone ) |
Projector |
---|---|---|---|
align |
冻结 ❄️ | 冻结 ❄️ | 训练 🔥 |
finetune / vla-train |
冻结 ❄️ | 训练 🔥 | 训练 🔥 |
full-finetune / vla-full-train |
训练 🔥(强制 float32) | 训练 🔥 | 训练 🔥 |
last-layer-finetune |
冻结 ❄️ | 仅最后层训练 🔥 | 冻结 ❄️ |
vla-sandwich-train |
训练 🔥(强制 float32) | 仅最后层训练 🔥 | 训练 🔥 |
但是在finetune的阶段,即使一个模块被冻结,它依然会经历前向传播。而在 AMP(混合精度)模式下,forward 时默认会把它的计算转换为 float16,这可能会导致部分组件如 FrozenBatchNorm2d
或 residual-add 操作数值不稳定:
❗可能出现的问题:
-
LayerNorm
精度不够,输出异常波动 -
Residual Add
下溢,结果不收敛 -
FrozenBatchNorm2d
精度丢失 -
梯度虽然不会更新,但 forward 会 silent 崩溃
✅ 正确做法:
with torch.no_grad():
vision_backbone.to(dtype=torch.float32)
autocast_dtype = self.llm_backbone.half_precision_dtype
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
generated_ids = super().generate(
input_ids=input_ids, # Shape: [1, seq]
pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
**kwargs
)
这样可以保证即使 frozen,forward 使用 float32 保留稳定性。
四、结论与实践建议
模块 | 包裹建议 | 精度建议 |
---|---|---|
Vision | 每个 Block 包裹 | 冻结时使用 float32 |
Projector | 整体包裹 | 可使用 AMP |
LLM | 每个 Layer 包裹 | 可使用 AMP,部分层可训练 |
-
✅
FSDP + AMP
是多模态训练中的标配组合 -
✅ 模块粒度包裹有利于显存优化与调试追踪
-
✅ 模块冻结也需要 dtype 管理,非“无脑冻结”
更多推荐
所有评论(0)