多模态学习系列(三):混合训练与FSDP 的技巧
本文将通过 PrismaticVLM 多模态架构为例,全面解析如何使用 PyTorch 的 FSDP(Fully Sharded Data Parallel)对 Vision、Projector 与 LLM 模块进行包裹,同时配合冻结策略实现高效稳定的多阶段训练。
对于参数量大于等于 5–7B 及以上的 LLM 主干,DDP 在常规 GPU 硬件上往往会导致显存溢出(OOM)。为了解决这一问题,我们采用 FSDP,其在显存利用效率上具有明显优势。和DDP类似,我们可以这么包装:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
self.vlm = FSDP(
self.vlm,
auto_wrap_policy=vlm_fsdp_wrapping_policy,
mixed_precision=fsdp_precision_policy,
sharding_strategy=self.fsdp_sharding_strategy,
device_id=torch.cuda.current_device(),
limit_all_gathers=True,
use_orig_params=True,
)
然而不同于 DDP,FSDP 需求更高的工程配置:自动 Wrapping 策略、Checkpoint 处理、精度控制都需要手动管理。
一、vision_backbone 使用 float32 的必要性
上一篇已经介绍了VLM分阶段训练的策略:
| 阶段 | Vision Encoder (vision_backbone) |
LLM (llm_backbone) |
Projector |
|---|---|---|---|
align |
冻结 ❄️ | 冻结 ❄️ | 训练 🔥 |
finetune / vla-train |
冻结 ❄️ | 训练 🔥 | 训练 🔥 |
full-finetune / vla-full-train |
训练 🔥(强制 float32) | 训练 🔥 | 训练 🔥 |
last-layer-finetune |
冻结 ❄️ | 仅最后层训练 🔥 | 冻结 ❄️ |
vla-sandwich-train |
训练 🔥(强制 float32) | 仅最后层训练 🔥 | 训练 🔥 |
在full finetune的阶段,在 AMP(混合精度)模式下,forward 时默认会把它的计算转换为 float16,这可能会导致部分组件如 FrozenBatchNorm2d 或 residual-add 操作数值不稳定:
❗可能出现的问题:
-
LayerNorm精度不够,输出异常波动 -
Residual Add下溢,结果不收敛 -
FrozenBatchNorm2d精度丢失 -
梯度虽然不会更新,但 forward 会 silent 崩溃
✅ 正确做法:
with torch.no_grad():
vision_backbone.to(dtype=torch.float32)
这样可以保证forward 使用 float32 保留稳定性。
而在finetune阶段,因为vision不参与训练:
-
不涉及反向传播中的数值稳定性问题
-
不存在梯度消失/爆炸的问题
-
模型的推理部分在低精度下一般是稳定的(特别是
bfloat16)
因此:✅ 可以将 vision_backbone 转为半精度,以节省显存、加速前向传播。
if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16:
# MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only)
# => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision
reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32
fsdp_precision_policy = MixedPrecision(
param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype
)
# When running FSDP with a frozen vision backbone --> move to half precision!
if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}:
overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`")
self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype)
else:
# If we're not using mixed precision, everything is in default full precision!
fsdp_precision_policy = MixedPrecision(
param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32
)
无论vision部分用什么精度,其它部分依然可以用半精度混合训练。
autocast_dtype = self.llm_backbone.half_precision_dtype
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
generated_ids = super().generate(
input_ids=input_ids, # Shape: [1, seq]
pixel_values=pixel_values, # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
**kwargs
)
二、FSDP 包裹策略的三类模块
在配置FSDP的时候,合理使用 自动包裹策略(auto wrap policy) 是显著降低显存占用与优化分布式性能的关键。PrismaticVLM 模型中涉及三类模块:
1. LLM 模块:LlamaDecoderLayer
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
llm_fsdp_wrapping_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer}
)
该策略递归扫描模型中所有子模块,并在每个 LlamaDecoderLayer 处自动打包。
2. Vision 模块:VisionTransformer
此策略直接包裹整个 VisionTransformer 模块本体。
若要精细到每个 Block,也可以使用:
vit_wrap_policy = partial(
_module_wrap_policy, # a custom wrap policy helper
module_classes={VisionTransformer}
)
transformer_block_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={Block}
)
vision_fsdp_wrapping_policy = partial(
_or_policy,
policies=[
vit_wrap_policy,
transformer_block_policy
],
)
3. Projector 模块:Linear / MLP
一般作为桥接图像与文本向量的 MLP,也可以通过 module_classes 来单独指定:
projector_fsdp_wrapping_policy = partial(
_module_wrap_policy,
module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
)
✅ 三者如何协同?
这些策略是互不冲突的,可以通过 _or_policy 组合使用:
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
vlm_fsdp_wrapping_policy = partial(
_or_policy,
policies=[
vision_fsdp_wrapping_policy,
llm_fsdp_wrapping_policy,
projector_fsdp_wrapping_policy,
],
)
最终结果是:每个核心模块都在适当粒度上被独立包装,避免冗余且利于 shard 化。
三、VLM 模型结构图(FSDP 包裹标记)
PrismaticVLM
├── vision_backbone (VisionTransformer)
│ ├── Block #1 ← wrap
│ ├── Block #2 ← wrap
│ └── ...
├── projector (Linear / MLP) ← wrap
└── llm_backbone (LLaMA)
├── LlamaDecoderLayer #1 ← wrap
├── LlamaDecoderLayer #2 ← wrap
└── ...
每一层的 wrap 都独立处理,在多卡训练时可最大化 shard 和通信效率。
对于vision的两个包裹策略,两者不会冲突,而是协同工作。最终结果是:
外层是
VisionTransformer的 FSDP 包装。内层每个
Block也被 FSDP 包装,实现更细粒度的分片。
除此之外,我们还需要修改和FSDP配套的checkpointing配置:
# Gradient Checkpoint Setup
if self.enable_gradient_checkpointing:
# For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the
# bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we
# cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics!
# Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer.
non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
def check_fn(submodule: nn.Module) -> bool:
return isinstance(submodule, self.llm_transformer_layer_cls)
# Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous!
apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
def clip_grad_norm(self) -> None:
# Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype*
# `clip_grad_norm_` is added by FSDP wrapper, not a method of the orignal VLM class.
self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm)
四、结论与实践建议
| 模块 | 包裹建议 | 精度建议 |
|---|---|---|
| Vision | 每个 Block 包裹 | 需要训练时使用 float32 |
| Projector | 整体包裹 | 可使用 AMP |
| LLM | 每个 Layer 包裹 | 可使用 AMP,部分层可训练 |
-
✅
FSDP + AMP是多模态训练中的标配组合 -
✅ 模块粒度包裹有利于显存优化与调试追踪
-
✅ 相比于DDP,FSDP 需求更高的工程配置
更多推荐



所有评论(0)