本文将通过 PrismaticVLM 多模态架构为例,全面解析如何使用 PyTorch 的 FSDP(Fully Sharded Data Parallel)对 Vision、Projector 与 LLM 模块进行包裹,同时配合冻结策略实现高效稳定的多阶段训练。


对于参数量大于等于 5–7B 及以上的 LLM 主干,DDP 在常规 GPU 硬件上往往会导致显存溢出(OOM)。为了解决这一问题,我们采用 FSDP,其在显存利用效率上具有明显优势。和DDP类似,我们可以这么包装:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
self.vlm = FSDP(
            self.vlm,
            auto_wrap_policy=vlm_fsdp_wrapping_policy,
            mixed_precision=fsdp_precision_policy,
            sharding_strategy=self.fsdp_sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            use_orig_params=True,
        )

然而不同于 DDP,FSDP 需求更高的工程配置:自动 Wrapping 策略、Checkpoint 处理、精度控制都需要手动管理。

一、vision_backbone 使用 float32 的必要性

上一篇已经介绍了VLM分阶段训练的策略:

阶段 Vision Encoder (vision_backbone) LLM (llm_backbone) Projector
align 冻结 ❄️ 冻结 ❄️ 训练 🔥
finetune / vla-train 冻结 ❄️ 训练 🔥 训练 🔥
full-finetune / vla-full-train 训练 🔥(强制 float32) 训练 🔥 训练 🔥
last-layer-finetune 冻结 ❄️ 仅最后层训练 🔥 冻结 ❄️
vla-sandwich-train 训练 🔥(强制 float32) 仅最后层训练 🔥 训练 🔥

在full finetune的阶段,在 AMP(混合精度)模式下,forward 时默认会把它的计算转换为 float16,这可能会导致部分组件如 FrozenBatchNorm2d 或 residual-add 操作数值不稳定:

❗可能出现的问题:

  • LayerNorm 精度不够,输出异常波动

  • Residual Add 下溢,结果不收敛

  • FrozenBatchNorm2d 精度丢失

  • 梯度虽然不会更新,但 forward 会 silent 崩溃

✅ 正确做法:

with torch.no_grad():
    vision_backbone.to(dtype=torch.float32)

 这样可以保证forward 使用 float32 保留稳定性

而在finetune阶段,因为vision不参与训练:

  • 不涉及反向传播中的数值稳定性问题

  • 不存在梯度消失/爆炸的问题

  • 模型的推理部分在低精度下一般是稳定的(特别是 bfloat16

因此:✅ 可以将 vision_backbone 转为半精度,以节省显存、加速前向传播

if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16:
     # MixedPrecision `param_dtype` specifies *compute* dtype (for forward/backward only)
     #   => Reference: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision
     reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32
     fsdp_precision_policy = MixedPrecision(
        param_dtype=torch.bfloat16, reduce_dtype=reduce_buffer_dtype, buffer_dtype=reduce_buffer_dtype
     )

      # When running FSDP with a frozen vision backbone --> move to half precision!
      if self.stage not in {"full-finetune", "vla-full-train", "vla-sandwich-train"}:
          overwatch.info("Casting Vision Backbone to *Half Precision* via `.to(dtype=...)`")
          self.vlm.vision_backbone.to(dtype=self.vlm.vision_backbone.half_precision_dtype)

  else:
      # If we're not using mixed precision, everything is in default full precision!
      fsdp_precision_policy = MixedPrecision(
      param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32
      )

无论vision部分用什么精度,其它部分依然可以用半精度混合训练。

autocast_dtype = self.llm_backbone.half_precision_dtype
with torch.autocast("cuda", dtype=autocast_dtype, enabled=self.enable_mixed_precision_training):
     generated_ids = super().generate(
                input_ids=input_ids,            # Shape: [1, seq]
                pixel_values=pixel_values,      # Shape: [1, 3, res, res] or Dict[str, Shape[1, 3, res, res]]
                **kwargs
     )

二、FSDP 包裹策略的三类模块

在配置FSDP的时候,合理使用 自动包裹策略(auto wrap policy) 是显著降低显存占用与优化分布式性能的关键。PrismaticVLM 模型中涉及三类模块:

1. LLM 模块:LlamaDecoderLayer

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
llm_fsdp_wrapping_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer}
)

该策略递归扫描模型中所有子模块,并在每个 LlamaDecoderLayer 处自动打包。


2. Vision 模块:VisionTransformer

此策略直接包裹整个 VisionTransformer 模块本体。

若要精细到每个 Block,也可以使用:

vit_wrap_policy = partial(
    _module_wrap_policy,  # a custom wrap policy helper
    module_classes={VisionTransformer}
)
transformer_block_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block}
)

vision_fsdp_wrapping_policy =  partial(
            _or_policy,
            policies=[ 
                vit_wrap_policy, 
                transformer_block_policy
                ],
            )

3. Projector 模块:Linear / MLP

一般作为桥接图像与文本向量的 MLP,也可以通过 module_classes 来单独指定:

projector_fsdp_wrapping_policy = partial(
            _module_wrap_policy,
            module_classes={LinearProjector, MLPProjector, FusedMLPProjector},
)

✅ 三者如何协同?

这些策略是互不冲突的,可以通过 _or_policy 组合使用:

from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy
 
vlm_fsdp_wrapping_policy =  partial(
            _or_policy,
            policies=[
                vision_fsdp_wrapping_policy,
                llm_fsdp_wrapping_policy,
                projector_fsdp_wrapping_policy,
                ],
            )

最终结果是:每个核心模块都在适当粒度上被独立包装,避免冗余且利于 shard 化。


三、VLM 模型结构图(FSDP 包裹标记)

PrismaticVLM
├── vision_backbone (VisionTransformer)
│   ├── Block #1     ← wrap
│   ├── Block #2     ← wrap
│   └── ...
├── projector (Linear / MLP) ← wrap
└── llm_backbone (LLaMA)
    ├── LlamaDecoderLayer #1 ← wrap
    ├── LlamaDecoderLayer #2 ← wrap
    └── ... 

每一层的 wrap 都独立处理,在多卡训练时可最大化 shard 和通信效率。

对于vision的两个包裹策略,两者不会冲突,而是协同工作。最终结果是:

  • 外层是 VisionTransformer 的 FSDP 包装。

  • 内层每个 Block 也被 FSDP 包装,实现更细粒度的分片。

除此之外,我们还需要修改和FSDP配套的checkpointing配置:

# Gradient Checkpoint Setup
if self.enable_gradient_checkpointing:
    # For Gradient Checkpointing under FSDP --> we make the same assumption as in the DDP/other strategies; the
    #   bulk of activation memory is taken up by the LLM activations. However, unlike other strategies, we
    #   cannot rely on the HF Transformers default `gradient_checkpointing_enable()` --> FSDP breaks semantics!
    # Instead, we need to write our own *NO-REENTRANT* wrapper, and apply it to the LLM's Transformer Layer.
    non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

    def check_fn(submodule: nn.Module) -> bool:
        return isinstance(submodule, self.llm_transformer_layer_cls)

    # Note that the terms "activation checkpointing" and "gradient checkpointing" are synonymous!
    apply_activation_checkpointing(self.vlm, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
def clip_grad_norm(self) -> None:
    # Note =>> FSDP uses a custom `clip_grad_norm_` function; requires *uniform grad dtype*
    # `clip_grad_norm_` is added by FSDP wrapper, not a method of the orignal VLM class.
    self.vlm.clip_grad_norm_(max_norm=self.max_grad_norm)

四、结论与实践建议

模块 包裹建议 精度建议
Vision 每个 Block 包裹 需要训练时使用 float32
Projector 整体包裹 可使用 AMP
LLM 每个 Layer 包裹 可使用 AMP,部分层可训练
  • FSDP + AMP 是多模态训练中的标配组合

  • ✅ 模块粒度包裹有利于显存优化与调试追踪

  • ✅ 相比于DDP,FSDP 需求更高的工程配置

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐