首先叠甲:
台大李宏毅讲解Transformer
这个看不了可以在bilibili上搜 讲的非常非常好 全网讲解transformer最好的 可以听完这个课再来看这篇文章
如果我的理解有问题的话 也欢迎大家在评论区指出更正!

我对Transformer的理解

首先要理解 transformer 在大模型开发中的作用,transformer是用于处理类似文本,语音,代码这样的序列数据的架构,在transformer中,有一个关键名词:注意力机制self-attention,那transformer其实就是用注意力机制代替原有的循环神经网络(RNN/LSTM),让大模型能够并行处理序列,学习长距离依赖关系。

原有神经网络,RNN/LSTM的痛点是什么?

1. 无法并行,只能从左到右读

因为我是java开发,所以这里我用java视角来举例:
RNN本质上是一个方法``step()``,必须一个词一个词顺序来执行。
public class RNN {
    RNNState hidden;
    public void process(List<String> words) {
        for (String w : words) {
            hidden = step(w, hidden);   // 必须按顺序执行
        }
    }
    
    public RNNState step(String word, RNNState prev) {
        // 词 + 上一个状态 = 新状态
        return tanh(Wx * word + Wh * prev);
    }
}

通俗例子就是:
接收到一段话:我 / 今天 / 去 / 天职师大 / 开会
RNN/LSTM的执行顺序必须严格为:
step("我") -> step("今天") -> step("去") -> step("天职师大") -> step("开会")
因为在例子中

“今天”的计算必须依赖“我”的结果,不能跳。

无法并行! 那这样的话,会导致长序列训练会非常非常慢,同时对于大模型训练完全不可行

2. 长距离依赖能力弱

RNN的信息传递方式如下:
h1 → h2 → h3 → h4 → h5 → …
词之间相隔越远,信息越容易被稀释、丢失
通俗一点的例子就是:
句子:

“尽管下雨,他还是去了学校。”

我们想判断:
“他” 是否与前面“下雨”的语境有关?

在 RNN 里:

“尽管” 信息传递给 h1
“下雨” 信息传递给 h2

“他” 直到 h5

路径大概是这样:

State h1 = step("尽管", init);
State h2 = step("下雨", h1);
State h3 = step(",", h2);
State h4 = step("他", h3); // 想理解“他”与“下雨”的关系,但信息已被稀释

“他”要找“下雨”的信息,需要倒推 h3 h4,中间信息已经被压缩很多次,导致:

  • 代词指代困难
  • 上下文关联弱
  • 长句理解能力很差

如果上面的例子还是有点难懂的话,我再举例:
就像前期的对话式大模型,有时候你跟它聊天,告诉了你的名字,但是聊的多了,你再问它,我叫什么名字,它却答不出来
这就是 长距离依赖被破坏的现象
JAVA伪代码:

State h = init();
h = step("我", h);       // h1
h = step("叫", h);       // h2
h = step("小明", h);     // h3   ← 这里记录了“用户叫小明”

// ... 过了很久 ...
for (int i = 0; i < 200; i++) {
    h = step("其他对话内容", h);  // h200
}

// 想要问:我叫什么?
String answer = model.predict("我叫什么名字?", h);

此时 h 已经经过 200 次压缩变换,“小明”相关的信息已经被稀释或者完全丢掉。

3. 梯度传播困难(梯度消失 / 梯度爆炸)

首先来说一下什么是梯度传播:
在训练神经网络模型的时候,就是在不断的更新参数或者说是权重
那么 更新的方式来自:

梯度 = 损失对参数的偏导数

简单来说:就是梯度告诉你 参数应该往哪个方向改,改多少?
在深度网络中,梯度需要从输出层 一路反向到最前面的层 这就是反向传播
那么当网络非常深,序列很长时,梯度就需要经过很多层或者说时间步
例子就是:一句话要一个人传一个人一样 会出现:

越传越小 → 信息消失 → 梯度消失
越传越大 → 信息夸张 → 梯度爆炸

那么就会导致模型训练不稳定非常严重的问题!

如果上面的概念讲解不是很明白的话 我再举一个通俗易懂的例子:
假如人物A 要告诉人物F 今晚来他家吃饭
关系链很长:
A → B → C → D → E → F
人物A要先告诉B,然后B告诉C,这样一直说下去;
那么可能会出现两种问题:

  1. 梯度消失问题→ 信息传输的越来越弱

传到 B 的时候变成:晚上吃饭。
传到 D 的时候变成:吃饭。
传到 F 的时候: “…饭?”

最终 信息越来越弱,最后基本消失。
RNN/LSTM 在反向传播时,梯度经过大量乘法 → 越乘越小 → 趋近 0 → 模型无法学习到最早的数据。
2. 梯度爆炸→ 信息越来越夸张
如果信息每传一次都会「被放大」。

A说:
“今晚 7 点吃饭。”
传到 D :
“今晚必须来吃大餐!”
传到 F:
“今晚一定要来参加超豪华宴会,不来会生气!”

信息逐层被放大,最终 梯度变得特别大,模型会:

  • 训练极度不稳定
  • loss 震荡
  • 参数发散
    这是因为梯度在链式反向传播时一直被乘以大于 1 的系数,越乘越大。

那这样来看,想要模型发展,就要解决这些痛点,那么transformer就是为解决这些问题而生的
持续更新中

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐