CS336 Assignment 1: Transformers Language Model Architecture

Feb. 03, 2026 · Qiyao Wang

Transformer Basics

语言模型的输入是经过 tokenizer 分词后的整数 token IDs 序列 (batch化),形状为 (batch_size, sequence_length),输出为正则化的词典中 token 的概率分布,形状为(batch_size, sequence_length, vocab_size)。训练语言模型时,计算输出概率分布和实际标签分布之间的交叉熵损失(cross-entropy loss)。在推理阶段,则使用当前时间步的下一个 token 的分布来采样解码得到预测 token,之后该 token 又会加入输入序列中继续预测下一时间步的 token。

基于Transformer架构的语言模型主要包括以下几个组件和步骤(见图1):给定 token IDs 序列,经过 Embedding Layer 将整数的 token ID 映射成稠密的向量,之后传入 num_layers 层的 Transformer blocks,取最后一层 Transformer block 的输出,通过一个 FFN(一般称为 LM head)来产生词典级别的输出 token 概率分布,之后进行解码采样得到预测的 token。

Transformer
图1:(Figure 1) Transformer 架构图; (Figure 2) Pre-norm Transformer Block.

Token Embeddings: (batch_size, sequence_length) $\to$ (batch_size, sequence_length, d_model).

发生在每个Transformer Block 中的 Attention 计算之前的 Normalization,Pre-norm Transformer Block: (batch_size, sequence_length, d_model) $\to$ (batch_size, sequence_length, d_model).

Output Normalization 会在最后一个 Transformer Block 后进行 Layer Normalization。

Efficient Computation

Transformer 架构中一般进行 batch-level 的并行计算。在具体分析的时候主要关注下面三个元素:

  • Elements of a batch: 每一次 forward computation 在一个 batch_size 大小的数据中进行,forward依次作用到 batch 中的每一个数据中。
  • Sequence length: 对于 RMSNorm 和 Feed-forward 等 point-wise 的位置无关操作,在序列中每个位置的处理都一样。
  • Attention heads: 注意力操作可以并行,也是 Transformer 可以并行的基础之一,并且可以基于 batch 来处理多头注意力操作。
举例:data tensor D 的形状为 (batch_size, sequence_length, d_model),在其上进行 batched vector-matrix multiply,矩阵为 A,形状是 (d_model, d_model),则 D @ A 会进行 batched matrix multiply,其中 (batch_size, sequence_length) 是批处理的。

使用 einsum notation 进行更方便的批处理。PyTorch 中是 torch.einsum,与框架无关的库是 einops (本次使用)和 einx

主要有两个关键的操作:

  • einsum: 可以对 tensor 做任意的操作,见下面的表格的总结。
  • 操作类型 einsum 表达式示例 张量形状示例 等价常见操作 说明
    转置 / 维度重排 einsum("ij->ji", A) A: (i, j) A.T, permute 只改变维度顺序,不做数值计算
    求和(Reduce) einsum("ij->i", A) A: (i, j) A.sum(dim=1) 消除未出现在输出中的维度
    逐元素乘法 einsum("ij,ij->ij", A, B) A,B: (i, j) A * B 对应元素相乘
    内积 / 点积 einsum("i,i->", x, y) x,y: (i) torch.dot 输出标量
    外积 einsum("i,j->ij", x, y) x:(i), y:(j) outer 生成高阶张量
    矩阵乘法 einsum("ij,jk->ik", A, B) A:(i,j), B:(j,k) matmul 经典线性代数操作
    批量矩阵乘 einsum("bij,bjk->bik", A, B) A,B:(b,i,j) bmm 每个 batch 独立计算
    Trace(迹) einsum("ii->", A) A:(i,i) trace 对角线求和
    Attention 核心计算 einsum("bhtd,bhsd->bhts", Q, K) Q,K:(b,h,t,d) Q @ Kᵀ Transformer 中的相似度计算
    线性层 / Token Mixing einsum("btc,cd->btd", X, W) X:(b,t,c), W:(c,d) Linear / MatMul 特征维度线性变换
    多操作融合 einsum("bct,td->bc", X, W) X:(b,c,t) permute + matmul + sum 一行完成多步计算
  • rearrange: 可以重新排列、连接和拆分张量的维度。
  • 操作类型 rearrange 表达式示例 输入形状示例 输出形状 等价常见操作 说明
    转置 / 维度重排 rearrange(x, "b c h w -> b h w c") (b, c, h, w) (b, h, w, c) permute 仅改变维度顺序,不改数值
    合并维度(merge) rearrange(x, "b t d -> b (t d)") (b, t, d) (b, t·d) reshape 将多个维度合并成一个
    拆分维度(split) rearrange(x, "b (h d) -> b h d", h=8) (b, h·d) (b, h, d) view 需要显式指定拆分尺寸
    Flatten rearrange(x, "b c h w -> b (c h w)") (b, c, h, w) (b, c·h·w) flatten 常用于 CNN → Transformer
    Unflatten rearrange(x, "b (h w) c -> b c h w", h=32) (b, h·w, c) (b, c, h, w) view 恢复空间结构
    Patch 划分(ViT) rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=16, p2=16) (b, c, H, W) (b, num_patches, patch_dim) 手写切块 Vision Transformer 核心操作
    多头拆分(Heads) rearrange(x, "b t (h d) -> b h t d", h=8) (b, t, h·d) (b, h, t, d) reshape + permute Attention 预处理
    多头合并 rearrange(x, "b h t d -> b t (h d)") (b, h, t, d) (b, t, h·d) reshape Attention 输出恢复
    Batch / Token 混合 rearrange(x, "b t d -> (b t) d") (b, t, d) (b·t, d) view 常用于共享线性层
    保持不变(Identity) rearrange(x, "b t d -> b t d") (b, t, d) (b, t, d) 用于代码可读性 / 占位

import torch
from einops import einsum, rearrange

def einstein_example1():
    """
    Batched matrix multiplication with einops.einsum
    """
    D = torch.rand(5, 3, 4)
    A = torch.rand(3, 4)

    # 普通矩阵乘法,不知道 D 和 A 具体的形状以及输出的形状
    Y = D @ A.T

    print(Y)

    # 使用 einsum 进行矩阵乘法,可读性更好
    Y = einsum(D, A, "batch sequence d_in, d_out d_in -> batch sequence d_out")

    # D 可以有任意的前置维度,A 受限
    Y = einsum(D, A, "... d_in, d_out d_in -> ... d_out")

    print(Y)

有一批图像,对于每个图像希望基于一些缩放因子生成 10 种暗淡版本

def einstein_example2():
    """
    Broadcasted operations with einops.rearrange
    """
    images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
    dim_by = torch.linspace(start=0.0, end=1.0, steps=10)

    # Reshape and multiply
    dim_value = rearrange(dim_by,    "dim_value              -> 1 dim_value 1 1 1")
    images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
    dimmed_images = images_rearr * dim_value

    # Use einsum to reshape and multiply in one step
    dimmed_images = einsum(
        images, dim_by,
        "batch height width channel, dim_value -> batch dim_value height width channel"
    )

Reference

[1] CS336 Assignment 1 (basics): Building a Transformer LM. Version 1.0.5

Contact

There may be some errors present. If you find any, please feel free to contact me at wangqiyao25@mails.ucas.ac.cn. I would appreciate it!