Transformer和GPT原理详解

基于 "Attention is All You Need" 论文的交互式学习工具

1. 为什么需要Transformer?

1.1 传统模型的局限性

模型类型 核心机制 优点 缺点 计算复杂度
RNN/LSTM 循环连接 • 天然处理序列
• 参数共享
• 无法并行计算
• 长距离依赖困难
• 梯度消失/爆炸
$O(n \cdot d^2)$
CNN 卷积操作 • 可并行计算
• 局部特征提取
• 感受野受限
• 需要多层堆叠
• 位置信息丢失
$O(k \cdot n \cdot d^2)$
Transformer 自注意力 • 完全并行化
• 全局依赖建模
• 动态权重
• 内存占用大
• 需要位置编码
$O(n^2 \cdot d)$

1.2 三种架构的信息流对比

2. Self-Attention 机制详解

2.1 核心概念

Query (Q) - 查询向量
表示当前位置需要关注的信息,用于计算与其他位置的相关性。
Key (K) - 键向量
表示每个位置的特征表示,用于被Query查询。
Value (V) - 值向量
表示每个位置的实际信息内容,根据注意力权重进行加权求和。

2.2 数学公式

Self-Attention计算公式:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

其中:

  • • $Q \in \mathbb{R}^{n \times d_k}$:查询矩阵
  • • $K \in \mathbb{R}^{n \times d_k}$:键矩阵
  • • $V \in \mathbb{R}^{n \times d_v}$:值矩阵
  • • $n$:序列长度
  • • $d_k$:键的维度
  • • $\sqrt{d_k}$:缩放因子,防止点积过大导致梯度消失

2.3 计算步骤演示

1 2 3 4 5

3. 多头注意力机制(Multi-Head Attention)

多头注意力公式:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O$$

其中每个注意力头计算为:

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

参数维度:

  • • $W_i^Q \in \mathbb{R}^{d_{model} \times d_k}$
  • • $W_i^K \in \mathbb{R}^{d_{model} \times d_k}$
  • • $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$
  • • $W^O \in \mathbb{R}^{hd_v \times d_{model}}$
  • • 通常设置:$d_k = d_v = d_{model} / h$
4

4. 位置编码(Positional Encoding)

正弦位置编码公式:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$ $$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

其中:

  • • $pos$:位置索引(0, 1, 2, ...)
  • • $i$:维度索引
  • • $d_{model}$:模型维度
序列长度: 20
嵌入维度: 64

5. Transformer完整架构

Encoder Decoder 完整架构

编码器(Encoder)

编码器由N个相同的层堆叠而成,每层包含两个子层:

  1. 多头自注意力机制(Multi-Head Self-Attention)
  2. 位置全连接前馈网络(Position-wise Feed-Forward Network)

前馈网络:

$$\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$$

其中:$W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$,$W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$

通常设置:$d_{ff} = 4 \times d_{model}$

残差连接与层归一化:

每个子层都使用残差连接和层归一化:

$$\text{LayerNorm}(x + \text{Sublayer}(x))$$

解码器(Decoder)

解码器也由N个相同的层堆叠而成,每层包含三个子层:

  1. Masked多头自注意力(防止看到未来信息)
  2. 编码器-解码器注意力(Cross-Attention)
  3. 位置全连接前馈网络
Mask机制:

解码器的自注意力使用下三角掩码矩阵,确保位置$i$只能关注位置$\leq i$的信息。

完整架构流程

输入嵌入
+ 位置编码
Encoder × N
Decoder × N
线性层 + Softmax
输出概率

6. GPT(Generative Pre-trained Transformer)

6.1 GPT vs 标准Transformer

特性 标准Transformer GPT
架构 编码器-解码器 仅解码器
注意力模式 编码器双向,解码器单向 单向(自回归)
预训练任务 监督学习(如翻译) 无监督语言建模
应用 序列到序列任务 文本生成、理解任务

6.2 自回归语言建模

训练目标:

$$L(\theta) = -\sum_{i=1}^{n} \log P(x_i | x_1, x_2, ..., x_{i-1}; \theta)$$

模型学习预测序列中的下一个词,给定所有之前的词。

6.3 GPT生成过程演示

输入: The cat

已生成:

当前预测:

1.0

7. 完整计算示例

7.1 简化示例:处理 "我爱你"

Step 1: 词嵌入(假设维度=4)

0.5
0.8
-0.3
0.2
0.3
-0.5
0.9
0.1
-0.2
0.7
0.4
0.6

8. PyTorch实现

8.1 Self-Attention实现

import torch import torch.nn as nn import torch.nn.functional as F import math class SelfAttention(nn.Module): def __init__(self, d_model, n_heads=8): super().__init__() self.d_model = d_model self.n_heads = n_heads self.d_k = d_model // n_heads self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, x, mask=None): batch_size, seq_len, _ = x.shape # 1. 计算Q, K, V Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2) K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2) V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2) # 2. 计算注意力分数 scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k) # 3. 应用mask(如果有) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # 4. Softmax attention_weights = F.softmax(scores, dim=-1) # 5. 加权求和 context = torch.matmul(attention_weights, V) # 6. 合并多头 context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model) # 7. 输出投影 output = self.W_o(context) return output, attention_weights

8.2 位置编码实现

class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1).float() div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe.unsqueeze(0)) def forward(self, x): return x + self.pe[:, :x.size(1)]

9. 模型参数计算器

层数 (N): 12
模型维度 (d_model): 768
注意力头数 (h): 12
词表大小 (V): 50000

参数统计

嵌入层:38.4M

注意力层:28.3M

前馈网络:56.6M

层归一化:0.02M

总参数量:123.3M