[{"content":"","date":"2026年6月24日","externalUrl":null,"permalink":"/tags/","section":"标签","summary":"","title":"标签","type":"tags"},{"content":"","date":"2026年6月24日","externalUrl":null,"permalink":"/categories/","section":"分类","summary":"","title":"分类","type":"categories"},{"content":"这里会介绍我和这个博客。\n","date":"2026年6月24日","externalUrl":null,"permalink":"/about/","section":"首页","summary":"","title":"关于","type":"page"},{"content":"","date":"2026年6月24日","externalUrl":null,"permalink":"/","section":"首页","summary":"","title":"首页","type":"page"},{"content":"","date":"2026年6月24日","externalUrl":null,"permalink":"/posts/","section":"文章","summary":"","title":"文章","type":"posts"},{"content":"这里会整理值得访问的朋友和站点。\n","date":"2026年6月24日","externalUrl":null,"permalink":"/friends/","section":"首页","summary":"","title":"友链","type":"page"},{"content":" Go Context 的使用 # 1. Context 本质说明 # Context 源码定义 # context.Context 本质上是一个接口，用来在多个函数、goroutine 或请求处理链路之间传递“上下文状态”。它最常见的用途不是传业务参数，而是传递取消信号、超时时间、截止时间以及少量跨调用链共享的数据。\n可以把它理解为一个轻量的控制器：上游创建 Context，下游持续监听它。一旦上游决定取消任务，或任务超过指定时间，下游就可以及时停止工作，避免 goroutine 泄漏或无意义的资源消耗。\nContext在 go 语言源码中定义如下：\ntype Context interface { Deadline() (deadline time.Time, ok bool) Done() \u0026lt;-chan struct{} // 只读 channel Err() error Value(key any) any } 这个接口只有四个方法：\n方法 作用 Deadline() 返回当前任务的截止时间，如果没有截止时间则 ok = false Done() 返回一个只读 channel，当任务被取消或超时时会关闭 Err() 返回取消原因，例如 context.Canceled 或 context.DeadlineExceeded Value(key) 根据 key 读取上下文中携带的值 因此，学习 Context 的核心就是理解两件事：如何把取消/超时信号传下去，以及下游如何监听这个信号并及时退出。\nemptyContext # emptyCtx 是标准库里最基础的空 Context 实现。它实现了 Context 接口要求的四个方法，但这些方法都只返回“空语义”：没有截止时间、没有取消信号、没有错误、也没有携带任何值。后面的 backgroundCtx 和 todoCtx 都是通过嵌入 emptyCtx，直接获得这四个方法，从而成为可用的 Context。源代码如下所示：\ntype emptyCtx struct{} func (emptyCtx) Deadline() (deadline time.Time, ok bool) { return } func (emptyCtx) Done() \u0026lt;-chan struct{} { return nil } func (emptyCtx) Err() error { return nil } func (emptyCtx) Value(key any) any { return nil } Context 的两个祖先：Background 与 TODO # Background() 和 TODO() 是标准库提供的两个最基础的根 Context。它们本身不会主动取消、没有截止时间，也不携带任何值。\n从源码上看，backgroundCtx 和 todoCtx 并没有各自重新实现 Deadline、Done、Err、Value 四个方法，而是都嵌入了前面提到的 emptyCtx。因此它们天然拥有了 emptyCtx 的空实现，只是在语义上分别表示“正式根节点”和“临时占位节点”。\ntype backgroundCtx struct{ emptyCtx } func (backgroundCtx) String() string { return \u0026#34;context.Background\u0026#34; } type todoCtx struct{ emptyCtx } func (todoCtx) String() string { return \u0026#34;context.TODO\u0026#34; } // Background returns a non-nil, empty [Context]. It is never canceled, has no // values, and has no deadline. It is typically used by the main function, // initialization, and tests, and as the top-level Context for incoming // requests. func Background() Context { return backgroundCtx{} } // TODO returns a non-nil, empty [Context]. Code should use context.TODO when // it\u0026#39;s unclear which Context to use or it is not yet available (because the // surrounding function has not yet been extended to accept a Context // parameter). func TODO() Context { return todoCtx{} } struct{ emptyCtx }：只有类型名，没有字段名。Go 会把它当成匿名字段，字段名默认就是类型名，逻辑上可以近似理解为：\ntype backgroundCtx struct { emptyCtx emptyCtx // 逻辑上可以这样理解，但源码写法更简洁 } 但需要注意，源码里的 struct{ emptyCtx } 是匿名字段写法，会带来方法提升；外层的 backgroundCtx 可以直接使用 emptyCtx 的 Deadline、Done、Err、Value 方法。\n这段源码可以拆成三层理解：\nemptyCtx 提供最基础的空 Context 行为。 backgroundCtx 嵌入 emptyCtx，并通过 String() 标记自己是 context.Background。 todoCtx 嵌入 emptyCtx，并通过 String() 标记自己是 context.TODO。 Background() 通常用于没有上游 Context 的地方，例如 main 函数、初始化逻辑、测试代码、后台任务入口，或者作为一个请求处理链路的最顶层 Context。\nTODO() 则表示“这里暂时还不知道该传哪个 Context”。它通常只作为过渡写法使用，比如函数还没来得及改造成接收 ctx context.Context 参数时，可以先用 context.TODO() 占位，后续再替换成真正的父级 Context。\n需要注意的是，Context 是一条向下传递的链路。没有上游 Context 时，可以从 context.Background() 开始派生：\nroot := context.Background() ctx, cancel := context.WithTimeout(root, 3*time.Second) defer cancel() 但如果函数已经收到了上游传入的 ctx，就应该从这个 ctx 继续派生：\nchildCtx, cancel := context.WithCancel(ctx) defer cancel() 这样上游的取消、超时和截止时间才能继续向下传递。如果在业务链路中随意重新使用 context.Background()，就会切断原有链路，下游可能无法感知请求已经结束。\n2. 四大派生函数 # 2.1 WithCancel：手动取消 # context.WithCancel(parent) 用来创建一个可以被手动取消的子 Context。它会返回两个值：\nctx, cancel := context.WithCancel(parent) ctx：派生出来的子 Context cancel：取消函数，调用后会立即关闭 ctx.Done() 返回的只读 channel 当下游任务正在监听 ctx.Done() 时，一旦上游调用 cancel()，下游就可以立刻感知取消信号，并提前结束当前任务。\n下面这个例子模拟“获取 IP 需要 3 秒，但主流程在 2 秒后主动取消任务”：\npackage main import ( \u0026#34;context\u0026#34; \u0026#34;fmt\u0026#34; \u0026#34;sync\u0026#34; \u0026#34;time\u0026#34; ) func main() { var wait sync.WaitGroup wait.Add(1) ctx, cancel := context.WithCancel(context.Background()) start := time.Now() go func() { defer wait.Done() ip, err := getIP(ctx) fmt.Println(\u0026#34;ip:\u0026#34;, ip, \u0026#34;err:\u0026#34;, err) }() go func() { time.Sleep(2 * time.Second) cancel() }() wait.Wait() fmt.Println(\u0026#34;执行完成:\u0026#34;, time.Since(start)) } func getIP(ctx context.Context) (string, error) { select { case \u0026lt;-time.After(3 * time.Second): return \u0026#34;192.168.200.1\u0026#34;, nil case \u0026lt;-ctx.Done(): return \u0026#34;\u0026#34;, ctx.Err() } } 这段代码中，getIP 并不是单纯 time.Sleep(3 * time.Second)，而是通过 select 同时等待两个事件：\ntime.After(3 * time.Second)：模拟耗时任务正常完成 ctx.Done()：监听上游是否取消任务 因为 2 秒后调用了 cancel()，所以 ctx.Done() 会先被关闭，getIP 会提前返回 context canceled，程序总耗时约 2 秒。\n需要注意，context 不会强制杀死 goroutine。它只是关闭 Done() channel，把取消信号传递出去；真正停止执行，需要下游函数自己监听 ctx.Done() 并主动返回。\n下面是一种容易写错的方式：\nfunc GetIp(ctx context.Context) (ip string, err error) { go func() { select { case \u0026lt;-ctx.Done(): fmt.Println(\u0026#34;取消\u0026#34;, ctx.Err().Error()) err = ctx.Err() wait.Done() return } }() time.Sleep(3 * time.Second) ip = \u0026#34;192.168.200.1\u0026#34; wait.Done() return } 这段代码的问题在于：内部新开的 goroutine 只是监听到了取消信号，但它不能让外层的 GetIp 立即停止。外层函数仍然会继续执行 time.Sleep(3 * time.Second)，然后继续返回 IP。\n同时，内部 goroutine 直接修改返回变量 err，还在里面调用 wait.Done()，这会让职责变得混乱：监听到取消并不等于整个 GetIp 任务已经完成。如果后面 GetIp 自己又调用一次 wait.Done()，还可能导致 WaitGroup 计数器被多减。\n更清晰的写法是：WaitGroup 由启动 goroutine 的外层负责，业务函数只负责业务逻辑和响应 ctx.Done()。\nWaitGroup 的工作方式可以理解为一个计数器：\nAdd(n)：登记有 n 个任务需要等待 Done()：表示有一个任务完成，计数器减一 Wait()：阻塞当前 goroutine，直到计数器归零 Add 和 Done 不会自动绑定到某一个具体 goroutine。机制上它只关心计数器是否归零；但工程实践中，应该让每个 Done() 对应一个真实完成的任务，避免在“监听到某个事件”时就提前把任务标记为完成。\n2.2 WithDeadline：截止时间自动取消 # context.WithDeadline(parent, deadline) 用来创建一个带“绝对截止时间”的子 Context。当系统时间到达 deadline 后，这个 Context 会自动取消，ctx.Done() 返回的 channel 会被关闭，ctx.Err() 会返回 context.DeadlineExceeded。\n它和 WithCancel 的区别在于：\nWithCancel：需要手动调用 cancel() 才会取消 WithDeadline：到达指定时间点后自动取消 仍然沿用前面“获取 IP 需要 3 秒”的例子。如果我们希望这个操作最晚只能执行到 2 秒后，就可以这样写：\npackage main import ( \u0026#34;context\u0026#34; \u0026#34;fmt\u0026#34; \u0026#34;sync\u0026#34; \u0026#34;time\u0026#34; ) func main() { var wait sync.WaitGroup wait.Add(1) deadline := time.Now().Add(2 * time.Second) ctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() start := time.Now() go func() { defer wait.Done() ip, err := getIP(ctx) fmt.Println(\u0026#34;ip:\u0026#34;, ip, \u0026#34;err:\u0026#34;, err) }() wait.Wait() fmt.Println(\u0026#34;执行完成:\u0026#34;, time.Since(start)) } func getIP(ctx context.Context) (string, error) { select { case \u0026lt;-time.After(3 * time.Second): return \u0026#34;192.168.200.1\u0026#34;, nil case \u0026lt;-ctx.Done(): return \u0026#34;\u0026#34;, ctx.Err() } } 这段代码中：\ndeadline := time.Now().Add(2 * time.Second) 表示截止时间是当前时间 2 秒后 getIP 原本需要 3 秒才能返回 IP 但 2 秒后 ctx.Done() 会自动关闭 getIP 会提前返回 context deadline exceeded 输出结果大致如下：\nip: err: context deadline exceeded 执行完成: 2.00... 这里的关键仍然是 getIP 内部的 select：\nselect { case \u0026lt;-time.After(3 * time.Second): return \u0026#34;192.168.200.1\u0026#34;, nil case \u0026lt;-ctx.Done(): return \u0026#34;\u0026#34;, ctx.Err() } time.After(3 * time.Second) 表示获取 IP 需要 3 秒后才会完成，而 ctx.Done() 会在 2 秒后因为到达截止时间而关闭。由于 ctx.Done() 先发生，所以 select 会进入取消分支，函数不会继续等待 3 秒，而是直接返回 ctx.Err()。\n所以 WithDeadline 的核心作用是：不是等待某个固定时长，而是指定一个明确的截止时间点。只要到达这个时间点，不管任务有没有完成，都会触发取消信号。\n另外，WithDeadline 虽然会在到达截止时间后自动取消，但它依然会返回一个 cancel 函数：\nctx, cancel := context.WithDeadline(context.Background(), deadline) defer cancel() 这里的 defer cancel() 不是为了手动提前取消，而是为了在函数结束时释放和这个 Context 相关的资源。实际开发中，只要调用了 WithCancel、WithTimeout、WithDeadline 这类函数，通常都应该在合适的位置调用返回的 cancel。\n实际开发中，如果你关心的是“最多执行多久”，通常用 WithTimeout 更直接；如果你关心的是“最晚不能超过某个具体时间点”，就更适合用 WithDeadline。\n2.3 WithTimeout：固定时长自动取消 # context.WithTimeout(parent, timeout) 用来创建一个带“固定超时时长”的子 Context。它和 WithDeadline 非常类似，最终效果都是：到达指定时间后自动取消，ctx.Done() 关闭，ctx.Err() 返回 context.DeadlineExceeded。\n不同点在于参数表达方式：\nWithDeadline 接收的是一个明确的时间点，例如 time.Now().Add(2 * time.Second) WithTimeout 接收的是一个时间长度，例如 2 * time.Second 从 Go 源码可以看出，WithTimeout 本质上就是对 WithDeadline 的一层封装：\nfunc WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { return WithDeadline(parent, time.Now().Add(timeout)) } 也就是说，下面这两种写法在语义上是接近的：\nctx, cancel := context.WithDeadline(parent, time.Now().Add(2*time.Second)) defer cancel() ctx, cancel := context.WithTimeout(parent, 2*time.Second) defer cancel() 在实际使用中，WithTimeout 更常见，因为很多场景关心的是“这个操作最多执行多久”，而不是“这个操作必须在哪个具体时间点之前结束”。\n2.4 WithValue：携带键值对 # WithValue 用来在 Context 链路中携带少量键值对数据。它的特点是：函数参数表面上只传了一个 ctx context.Context，但下游可以通过 ctx.Value(key) 读取上游放进去的值。\n因此，WithValue 必须和 Value() 方法搭配使用：\n上游使用 context.WithValue(parent, key, value) 写入值 下游使用 ctx.Value(key) 读取值 示例：\npackage main import ( \u0026#34;context\u0026#34; \u0026#34;fmt\u0026#34; ) type contextKey string const userIDKey contextKey = \u0026#34;userID\u0026#34; func main() { ctx := context.Background() ctx = context.WithValue(ctx, userIDKey, \u0026#34;user-001\u0026#34;) handleRequest(ctx) } func handleRequest(ctx context.Context) { printUserID(ctx) } func printUserID(ctx context.Context) { userID, ok := ctx.Value(userIDKey).(string) if !ok { fmt.Println(\u0026#34;未找到 userID\u0026#34;) return } fmt.Println(\u0026#34;当前用户:\u0026#34;, userID) } 从 main 到 printUserID，函数之间看起来只传递了一个 ctx，但 userID 已经被挂在这条 Context 链路上，因此下游可以在需要时取出来。\n需要注意的是，WithValue 不适合传递主要业务参数。它更适合放请求级别的元信息，例如用户 ID、trace ID、请求 ID 等。为了避免 key 冲突，通常会自定义 key 类型，而不是直接使用普通字符串作为 key。\n3. Context 的 Timeout 继承链 # Context 是一条从父级向子级派生的链路。这里说的“继承”不是面向对象里的继承，而是指：子 Context 会受到父 Context 的影响。\n最重要的规则是：\n父 Context 被取消后，所有由它派生出来的子 Context 都会被取消 子 Context 被取消后，不会反向取消父 Context 如果父子 Context 都设置了截止时间，最终会以更早到期的那个为准 也就是说，子 Context 不能突破父 Context 的生命周期。如果父级最多只能执行 2 秒，那么子级即使设置了 5 秒，也不会真的等到 5 秒后才取消，因为父级在 2 秒时已经取消了。\n例如：\nparentCtx, parentCancel := context.WithTimeout(context.Background(), 2*time.Second) defer parentCancel() childCtx, childCancel := context.WithTimeout(parentCtx, 5*time.Second) defer childCancel() 这段代码中，parentCtx 的超时时间是 2 秒，childCtx 的超时时间是 5 秒。虽然子 Context 设置得更长，但它是从 parentCtx 派生出来的，所以它最多也只能存活 2 秒。2 秒后父 Context 超时，子 Context 的 Done() 也会一起关闭。\n如果反过来：\nparentCtx, parentCancel := context.WithTimeout(context.Background(), 5*time.Second) defer parentCancel() childCtx, childCancel := context.WithTimeout(parentCtx, 2*time.Second) defer childCancel() 这时父 Context 可以存活 5 秒，但子 Context 自己只设置了 2 秒，所以 childCtx 会在 2 秒后先被取消，而 parentCtx 不会因为子级取消而取消。\n因此，对于存在父子关系的 Context 来说，实际生效的截止时间可以理解为：\n实际截止时间 = 父 Context 截止时间 和 子 Context 截止时间 中更早的那个 这个规则可以保证上游的控制能力不会被下游绕开。上游一旦设置了整体超时时间，下游即使继续派生新的 Context，也只能在这个整体时间范围内工作。\n参考 # 20分钟搞懂go语言中的context——枫枫知道 go语言的context——枫枫知道的博客 golang context该怎么玩儿\n","date":"2026年6月9日","externalUrl":null,"permalink":"/posts/go-basic/go-context/","section":"文章","summary":"","title":"Go Context 的使用和说明","type":"posts"},{"content":"","date":"2026年6月9日","externalUrl":null,"permalink":"/categories/go-%E8%AF%AD%E8%A8%80%E5%9F%BA%E7%A1%80/","section":"分类","summary":"","title":"Go 语言基础","type":"categories"},{"content":"","date":"2026年6月9日","externalUrl":null,"permalink":"/tags/golang/","section":"标签","summary":"","title":"GoLang","type":"tags"},{"content":" 记录 Go Modules 依赖管理、SDK 版本管理、基础语法、并发模式与类型断言等 Go 语言基础知识点。\n1. Go Modules 依赖管理 # 1.1 核心理念：与 Python 的本质区别 # 维度 Python（pip/conda） Go（Go Modules） 隔离方式 需要创建虚拟环境（venv / conda env），否则全局污染 天然隔离，无需创建虚拟环境 依赖存储 每个虚拟环境各存一份副本，磁盘占用大 全局统一缓存池，按精确版本号分别存储，多项目复用同一份文件 版本冲突 同一环境内不允许同一个包的两个版本共存 全局缓存中可同时存在 v1.8.0 和 v1.9.0，编译时按 go.mod 清单精确引用 激活操作 每次需要 conda activate xxx 无需任何激活步骤，进入项目目录即自动生效 1.2 go.mod 文件 # 每个 Go 项目根目录下都有一个 go.mod 文件，它声明了：\n当前模块的名称 所需的 Go 最低版本 依赖的第三方包及其精确版本 初始化一个新模块：\ngo mod init \u0026lt;模块名\u0026gt; # 例如：go mod init hello 示例 go.mod 内容：\nmodule hello go 1.21 1.3 依赖缓存位置 # 所有下载的第三方依赖统一存放在：$GOPATH/pkg/mod\n通过以下命令安装的全局工具存放在：$GOPATH/bin\n2. Go SDK 版本管理 # 2.1 适用场景 # 当不同项目需要不同的 Go 语言版本时（如旧系统需要 Go 1.16，新项目使用 Go 1.22 的泛型特性），需要进行 SDK 版本管理。\n2.2 方法 A：官方原生多版本（适合偶尔切换） # # 安装指定版本的 Go go install golang.org/dl/go1.16.15@latest go1.16.15 download # 使用指定版本编译项目 go1.16.15 build main.go go1.16.15 run main.go 不影响当前主力版本，两者并存。\n2.3 方法 B：第三方版本管理器（适合频繁切换） # 工具名 平台 说明 gvm (Go Version Manager) Linux / macOS 类似 Node 的 nvm，通过 gvm use go1.21.0 切换版本 g Windows / 跨平台 轻量级版本管理器，Windows 下推荐使用 3. Hello World 快速上手 # 3.1 最小项目结构 # hello/ ├── go.mod ← 模块声明文件 └── main.go ← 程序入口 3.2 代码示例 # package main // 声明属于 main 包，Go 中可执行程序必须在 main 包内 import \u0026#34;fmt\u0026#34; // 导入格式化输出标准库 func main() { // main 函数是程序执行入口 fmt.Println(\u0026#34;Hello, World!\u0026#34;) } 3.3 运行方式 # # 方式一：直接运行（开发阶段常用） go run main.go # 方式二：编译为可执行文件 go build .\\hello.exe 附：常用 go env 命令速查 # go env # 查看所有环境变量 go env GOPATH # 查看 GOPATH go env GOPROXY # 查看代理配置 go env GOROOT # 查看 Go SDK 安装目录 go env -w KEY=VALUE # 永久写入环境变量配置 4. 单元测试基础（testing 包） # 4.1 测试函数规范 # Go 内置了完整的测试框架，无需安装第三方依赖。测试文件以 _test.go 结尾，测试函数必须满足以下约定：\n函数名以大写 Test 开头，后接被测功能名称（如 TestHello） 接收唯一参数 t *testing.T，用于控制测试流程和报告结果 4.2 经典测试三部曲 # func TestHello(t *testing.T) { got := say_hello() // 第一步：执行目标函数，获取实际结果 want := \u0026#34;Hello World!\u0026#34; // 第二步：定义期望的正确结果 if got != want { // 第三步：比对两者，不一致则报错 t.Errorf(\u0026#34;got \u0026#39;%q\u0026#39; want \u0026#39;%q\u0026#39;\u0026#34;, got, want) } } t.Errorf：将测试标记为失败，并打印格式化错误信息 %q：以带双引号的安全形式打印字符串，方便发现空格、特殊字符等细微差异 4.3 运行测试 # go test # 运行当前包的所有测试 go test -v # 显示详细输出（包括每个测试函数的名称和结果） go test -run Xxx # 仅运行函数名匹配 Xxx 的测试 5. 变量声明与赋值 # 5.1 三种声明方式对比 # 方式 语法 适用场景 完整声明 var name string = \u0026quot;张三\u0026quot; 全局变量、需要显式指定类型时 类型推断声明 var name = \u0026quot;张三\u0026quot; 全局变量、右侧类型明确时可省略类型 简短声明 name := \u0026quot;张三\u0026quot; 仅限函数内部，最常用的写法 5.2 :=（简短变量声明） # := 同时完成 声明新变量 + 赋值 两个动作，并自动推断右侧值的类型：\ngot := say_hello() // 自动推断 got 的类型为 say_hello() 的返回类型 count := 42 // 自动推断为 int pi := 3.14 // 自动推断为 float64 限制：\n只能在函数内部使用，全局变量必须用 var 左侧必须至少有一个新变量，不能对已存在的变量重复使用 := 5.3 =（赋值运算符） # = 仅用于修改已声明变量的值，左侧变量必须已经存在：\ndrink := \u0026#34;可乐\u0026#34; // 创建新变量（:=） drink = \u0026#34;雪碧\u0026#34; // 修改已有变量的值（=） // drink := \u0026#34;雪碧\u0026#34; // ❌ 错误：不能对已存在的变量再次使用 := 5.4 速记口诀 # := → \u0026ldquo;无中生有\u0026rdquo;（声明 + 赋值） = → \u0026ldquo;喜新厌旧\u0026rdquo;（仅修改旧变量的值） 6. 整数类型与类型系统 # 6.1 int 的默认大小 # int 的实际位宽由编译目标平台的架构决定：\n平台架构 int 大小 等效于 32 位系统（GOARCH=386） 32 位 / 4 字节 int32 64 位系统（GOARCH=amd64） 64 位 / 8 字节 int64 当今绝大多数开发机和服务器均为 64 位，因此日常使用中 int 实际等同于 int64。\n6.2 强类型注意事项 # 即使 int 在 64 位系统上与 int64 位宽相同，Go 编译器依然视它们为不同类型，不能直接互相赋值：\nvar a int = 100 var b int64 // b = a // ❌ 编译错误：cannot use a (type int) as type int64 b = int64(a) // ✅ 必须显式类型转换 6.3 选型建议 # 一般场景：直接用 int，性能最优（贴合硬件字长），且与标准库 API 兼容（切片长度、循环变量等均为 int） 精确控制：网络协议、二进制文件读写等对字节数有严格要求的场景，使用 int8 / int16 / int32 / int64 无符号整数：对应有 uint、uint8（即 byte）、uint16、uint32、uint64 7. sort 包排序机制详解 # 7.1 核心组件分工 # sort 包采用**\u0026ldquo;接口解耦\u0026rdquo;**设计：规则定义与排序执行分离。\n组件 类型 作用 sort.IntSlice(x) 包装类型 将 []int 包装为实现排序接口的结构体，提供升序比较规则，本身不执行排序 sort.Reverse(...) 规则修饰器 将传入的比较规则反转（大小颠倒），本身不执行排序 sort.Sort(...) 执行函数 唯一的执行者，根据传入的规则对数据进行排序 sort.Ints(x) 快捷函数 sort.Sort(sort.IntSlice(x)) 的简写，一步完成升序排序 7.2 代码拆解示例 # // 降序排列：从大到小 sort.Sort(sort.Reverse(sort.IntSlice(ints))) 执行流程（从内到外）：\nsort.IntSlice(ints) → 套上\u0026quot;升序规则\u0026quot;外壳 sort.Reverse(...) → 将规则反转为\u0026quot;降序规则\u0026quot;（此时数据未被移动） sort.Sort(...) → 按照降序规则执行真正的排序算法 7.3 常见易错点 # sort.Int 不存在，会编译报错；正确写法是 sort.IntSlice（包装类型）或 sort.Ints（快捷函数） sort.Reverse 不是物理反转数组，它只是反转比较规则 7.4 推荐的现代写法 # // 方式一：自定义排序规则（Go 1.8+） sort.Slice(ints, func(i, j int) bool { return ints[i] \u0026gt; ints[j] // 降序 }) // 方式二：使用 slices 标准库（Go 1.21+） import \u0026#34;slices\u0026#34; slices.Sort(ints) // 升序 slices.Reverse(ints) // 物理反转（这里的 Reverse 才是真正的反转动作） 8. 条件判断惯用写法 # 8.1 Go 支持 else if # Go 语法完整支持 if / else if / else 结构，唯一的格式限制是 else if 和 else 必须紧跟在前一个 } 的同一行：\nif score \u0026gt;= 90 { fmt.Println(\u0026#34;优秀\u0026#34;) } else if score \u0026gt;= 60 { // ✅ 必须和 } 同行 fmt.Println(\u0026#34;及格\u0026#34;) } else { fmt.Println(\u0026#34;不及格\u0026#34;) } ⚠️ 如果将 else if 换行放到新的一行，编译器会直接报错。这是 Go 对大括号位置的强制要求。\n8.2 卫语句（Guard Clause）：异常前置拦截 # 当某些条件不满足时函数应尽早退出，避免深层嵌套：\nfunc SaveUser(user *User) error { if user == nil { return errors.New(\u0026#34;用户不存在\u0026#34;) // 提前退出 } if user.Age \u0026lt; 0 { return errors.New(\u0026#34;年龄非法\u0026#34;) // 提前退出 } // 核心业务：所有校验通过后才执行，无嵌套负担 db.Save(user) return nil } 8.3 选择策略 # 场景 推荐写法 原因 同级互斥的业务分流（各分支后续还有公共逻辑） if / else if / else 各分支平等，无法提前 return 异常校验、前置拦截（不满足条件直接退出） 卫语句（提前 return） 减少嵌套，保持主干扁平 平行条件分支较多（≥3 个） 无条件 switch 比长链 else if 更简洁 8.4 无条件 switch：多分支的优雅替代 # switch { case score \u0026gt;= 90: fmt.Println(\u0026#34;优秀\u0026#34;) case score \u0026gt;= 60: fmt.Println(\u0026#34;及格\u0026#34;) default: fmt.Println(\u0026#34;不及格\u0026#34;) } 9. 值传递与指针 # 9.1 核心铁律：一切皆值传递 # Go 中所有的参数传递都是值拷贝（Pass by Value），不存在 C++ 中的引用传递（Pass by Reference）。想要在函数内修改外部变量，唯一的方式是传递指针：\nfunc change(age *int) { // 接收指针 *age = 20 // 通过指针修改原始值 } func main() { a := 18 change(\u0026amp;a) // 传递 a 的内存地址 fmt.Println(a) // 输出 20 } 9.2 \u0026ldquo;引用类型\u0026quot;的真相 # map、slice、channel 在日常中被称为\u0026quot;引用类型\u0026rdquo;，因为将它们传入函数后，函数内部的修改会影响到外部。但这并不是真正的引用传递，而是因为它们的底层结构体内嵌了指针。\n以切片为例，[]int 的底层结构：\ntype slice struct { array unsafe.Pointer // 指向底层数组的指针 len int // 切片长度 cap int // 切片容量 } 传入函数时，Go 依然执行值拷贝——复制了这三个字段。但由于 array 指针的副本和原件指向同一块底层数组，因此：\n✅ 修改元素 slice[0] = 99：通过指针改的是公共底层数组，外部可见 ❌ 扩容 append(slice, 100)：可能导致底层数组重新分配，新指针仅存在于函数内部副本中，外部不受影响 9.3 总结 # 概念 Go 中是否存在 说明 值传递 ✅ 唯一的传递方式 所有参数传递（包括指针本身）都是值拷贝 引用传递 ❌ 不存在 没有 C++ 那样的 \u0026amp; 引用语法 指针 ✅ 存在 通过 * 和 \u0026amp; 操作符使用，是修改外部变量的唯一手段 \u0026ldquo;引用类型\u0026rdquo; ⚠️ 仅为惯用叫法 map / slice / channel 底层内嵌指针，表现出类似引用的行为，但本质仍是值传递 10. 协程（Goroutine）与并发等待 # 10.1 协程基础 # Go 使用 go 关键字启动协程（goroutine），它是 Go 运行时管理的轻量级线程，创建成本极低（初始栈仅约 2KB）：\ngo pay(\u0026#34;张三\u0026#34;) // 启动一个新协程执行 pay 函数 go pay(\u0026#34;李四\u0026#34;) // 再启动一个，与上一个并行执行 go pay(\u0026#34;王五\u0026#34;) // 三个协程几乎同时运行 关键问题：主协程（main 函数）不会等待子协程完成。一旦 main 执行到末尾，整个程序立即退出，无论子协程是否结束。因此必须有一种机制来\u0026quot;卡住\u0026quot;主协程，等所有子协程跑完再退出。\n10.2 方案一：sync.WaitGroup（推荐） # sync.WaitGroup 是标准库提供的并发安全计数器，专门用于等待一组协程完成：\nvar wg sync.WaitGroup func pay(name string) { defer wg.Done() // 函数结束时计数器 -1 fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) } func main() { startTime := time.Now() wg.Add(3) // 告知计数器：共有 3 个协程需要等待 go pay(\u0026#34;张三\u0026#34;) go pay(\u0026#34;李四\u0026#34;) go pay(\u0026#34;王五\u0026#34;) wg.Wait() // 阻塞主协程，直到计数器归 0 fmt.Println(\u0026#34;总耗时:\u0026#34;, time.Since(startTime)) } 三步走：\n方法 作用 wg.Add(n) 设置需要等待的协程数量（计数器 +n） wg.Done() 协程结束时调用（计数器 -1），通常配合 defer wg.Wait() 阻塞当前协程，直到计数器减至 0 defer 的意义：即使函数中途发生 panic，defer wg.Done() 也保证会被执行，避免计数器永远无法归 0 导致程序死锁。\n10.3 方案二：手动全局计数器 + 轮询（学习原理用） # 如果不借助 WaitGroup，可以用全局变量 + 死循环轮询模拟相同效果。但需要解决两个关键问题：\n问题一：并发安全——num-- 不是原子操作 # 多个协程同时修改全局变量，可能导致数据竞争（Data Race）。解决方案是使用 sync/atomic 包提供的原子操作：\nvar num int32 = 3 func pay(name string) { fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) atomic.AddInt32(\u0026amp;num, -1) // 原子减 1，硬件级别保证安全 } 普通操作 原子操作 区别 num-- atomic.AddInt32(\u0026amp;num, -1) 前者多协程同时写入会数据错乱 if num == 0 atomic.LoadInt32(\u0026amp;num) == 0 前者可能读到\u0026quot;写了一半\u0026quot;的脏数据 问题二：CPU 空转——死循环吃满性能 # 主协程用 for {} 循环检查计数器，如果不做任何让步，会导致 CPU 使用率飙升到 100%：\nfor { if atomic.LoadInt32(\u0026amp;num) == 0 { break } runtime.Gosched() // 主动让出 CPU 时间片，让子协程有机会执行 // 或使用 time.Sleep(1 * time.Millisecond) 短暂休眠 } 让步方式 效果 runtime.Gosched() 立即让出当前时间片，调度器安排其他协程运行 time.Sleep(1 * time.Millisecond) 休眠 1ms 后再检查，CPU 基本无负载 不做任何让步 ❌ CPU 单核跑满 100%，严重浪费资源 10.4 两种方案对比 # 维度 sync.WaitGroup 手动计数器 + 轮询 并发安全 内部已处理，开箱即用 需要手动使用 atomic 包保证 等待机制 真正的阻塞（不消耗 CPU） 忙等轮询（需 Gosched / Sleep 降耗） 代码复杂度 低（3 行核心代码） 高（需处理原子操作 + CPU 让步） 适用场景 生产环境首选 仅用于学习理解底层原理 性能 最优 有额外轮询开销 结论：生产代码中应始终使用 sync.WaitGroup，它本质上就是对\u0026quot;原子计数器 + 高效阻塞唤醒\u0026quot;的封装。手动轮询方式作为学习理解的手段非常有价值，但不建议在实际项目中使用。\n11. sync 包对象禁止值拷贝 # 11.1 问题现象 # 将 sync.WaitGroup 作为值类型参数传递给函数时，编译器会抛出警告：\ncall of pay copies lock value: sync.WaitGroup contains sync.noCopy 11.2 出错的代码 # func pay(name string, wait sync.WaitGroup) { // ❌ 值传递，产生副本 fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) wait.Done() // 操作的是副本的计数器，对原始对象无效 } func main() { var wait sync.WaitGroup wait.Add(3) go pay(\u0026#34;张三\u0026#34;, wait) // ❌ 把整个 WaitGroup 拷贝了一份传进去 go pay(\u0026#34;李四\u0026#34;, wait) go pay(\u0026#34;王五\u0026#34;, wait) wait.Wait() // 永远等不到计数器归零 → 死锁 } 11.3 原因分析 # sync.WaitGroup 内部嵌入了 sync.noCopy 标记，编译器静态分析工具（go vet）会检测到拷贝行为并发出警告 值传递时，Go 会创建 WaitGroup 的完全独立副本 协程内调用 wait.Done() 只会减少副本的计数器，main 中原始的计数器纹丝不动 主协程的 wait.Wait() 永远等不到计数器归 0，最终触发死锁（deadlock） 11.4 正确做法：传指针 # func pay(name string, wait *sync.WaitGroup) { // ✅ 指针传递 defer wait.Done() fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) } func main() { var wait sync.WaitGroup wait.Add(3) go pay(\u0026#34;张三\u0026#34;, \u0026amp;wait) // ✅ 传递地址，所有协程共享同一个计数器 go pay(\u0026#34;李四\u0026#34;, \u0026amp;wait) go pay(\u0026#34;王五\u0026#34;, \u0026amp;wait) wait.Wait() } 11.5 适用范围 # 该规则适用于 sync 包下的所有同步原语，它们都内嵌了 noCopy 标记：\n类型 说明 必须指针传递 sync.WaitGroup 等待计数器 ✅ sync.Mutex 互斥锁 ✅ sync.RWMutex 读写锁 ✅ sync.Cond 条件变量 ✅ sync.Once 单次执行 ✅ sync.Map 并发安全字典 ✅ 口诀：sync 包里的家伙，一个都不能复印，只能给地址。\n12. Channel 与 WaitGroup 协作模式 # 12.1 场景 # 当协程不仅需要同步等待，还需要回传结果数据时，需要将 Channel（信道）与 WaitGroup 配合使用。\n12.2 完整示例 # var moneyChan = make(chan int) // 无缓冲信道 func pay(name string, money int, wait *sync.WaitGroup) { defer wait.Done() fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) moneyChan \u0026lt;- money // 向信道发送数据 } func main() { var wait sync.WaitGroup wait.Add(3) startTime := time.Now() go pay(\u0026#34;张三\u0026#34;, 2, \u0026amp;wait) go pay(\u0026#34;李四\u0026#34;, 3, \u0026amp;wait) go pay(\u0026#34;王五\u0026#34;, 5, \u0026amp;wait) // 关键：用独立协程等待完成并关闭信道 go func() { defer close(moneyChan) wait.Wait() }() var moneyList []int for val := range moneyChan { moneyList = append(moneyList, val) } fmt.Println(\u0026#34;总耗时:\u0026#34;, time.Since(startTime)) fmt.Println(moneyList) } 12.3 为什么关闭信道的逻辑必须用 go 启动新协程？ # 这是本模式中最容易出错的地方。如果不加 go，直接在主协程中同步执行：\n// ❌ 错误写法：会死锁 func() { defer close(moneyChan) wait.Wait() // 主协程被卡在这里 }() // 下面这行永远执行不到 for val := range moneyChan { ... } 死锁的时序分析：\n┌─────────────────────────────────────────────────────────────────┐ │ ❌ 不加 go 的情况 │ ├──────────────────┬──────────────────────────────────────────────┤ │ 主协程 │ pay 协程（×3） │ ├──────────────────┼──────────────────────────────────────────────┤ │ wait.Wait() 阻塞 │ 执行到 moneyChan \u0026lt;- money │ │ 等 Done() 被调用 │ 无缓冲信道，等待有人接收…… │ │ │ 没人接收 → 无法继续 → Done() 永远不会被调用 │ │ 永远等不到 Done │ │ │ 💀 双方互等 = 死锁 │ └─────────────────────────────────────────────────────────────────┘ 正确流程（加 go 后）：\n┌─────────────────────────────────────────────────────────────────┐ │ ✅ 加 go 的情况 │ ├──────────────────┬───────────────┬──────────────────────────────┤ │ 主协程 │ 监听协程 │ pay 协程（×3） │ ├──────────────────┼───────────────┼──────────────────────────────┤ │ │ wait.Wait() │ │ │ │ 后台等待中…… │ │ │ for range 接收数据│ │ moneyChan \u0026lt;- money 发送成功 │ │ 收到数据 ✓ │ │ defer Done() 计数器 -1 │ │ ……继续接收…… │ │ ……3个协程依次完成…… │ │ │ 计数器归 0 ✓ │ │ │ │ close(chan) ✓ │ │ │ range 感知关闭 │ │ │ │ 退出循环 ✓ │ │ │ └──────────────────┴───────────────┴──────────────────────────────┘ 核心要点：无缓冲信道的发送操作（chan \u0026lt;- data）必须有对应的接收方同时就绪，否则发送方会被永久阻塞。用 go 把等待逻辑放到后台，就能让主协程腾出手来当接收方。\n13. 匿名函数与立即调用（IIFE） # 13.1 语法结构 # Go 支持匿名函数（Anonymous Function），即没有名字的函数。它可以被赋值给变量、作为参数传递，或者定义后立即调用：\n// 定义并立即调用（IIFE：Immediately Invoked Function Expression） func() { fmt.Println(\u0026#34;我被立即执行了\u0026#34;) }() // ← 这对括号 = 立刻调用 13.2 拆解理解 # 将 func() { ... }() 拆成两部分：\n部分 含义 func() { ... } 定义一个匿名函数（只是造出来了） 末尾的 () 调用这个函数（立刻执行它） 如果不加末尾的 ()，就只是一个函数值（可以赋给变量），但不会被执行。\n13.3 带参数的 IIFE # 末尾的括号和普通函数调用一样，可以传入参数：\ngo func(msg string, times int) { for i := 0; i \u0026lt; times; i++ { fmt.Println(msg) } }(\u0026#34;你好\u0026#34;, 3) // ← 传入参数并立即执行 13.4 常见用途 # 用途 示例 配合 go 启动携带逻辑的协程 go func() { ... }() 配合 defer 延迟执行一段逻辑 defer func() { fmt.Println(\u0026quot;收尾\u0026quot;) }() 闭包捕获外部变量 go func(id int) { process(id) }(i) — 避免循环变量陷阱 注意：go 关键字后面必须跟一个函数调用（而非函数定义），所以 go func() { ... } 不加 () 会编译报错。defer 同理。\n14. 无缓冲信道多通道死锁分析 # 14.1 问题场景 # 当一个协程需要向多个无缓冲信道依次发送数据，而主协程用顺序 for range 依次读取时，极易发生死锁：\nvar moneyChan = make(chan int) var nameChan = make(chan string) func pay(name string, money int, wait *sync.WaitGroup) { defer wait.Done() moneyChan \u0026lt;- money // 第一步：发送 money nameChan \u0026lt;- name // 第二步：发送 name（必须等第一步完成） } func main() { // ...启动 3 个 pay 协程... go func() { defer close(moneyChan) defer close(nameChan) wait.Wait() }() // ❌ 顺序读取：先读完 moneyChan，再读 nameChan for money := range moneyChan { moneyList = append(moneyList, money) } for name := range nameChan { nameList = append(nameList, name) } } 14.2 死锁时序分析 # ┌──────────────────────────────────────────────────────────────────────┐ │ 死锁全过程 │ ├──────────────────┬───────────────────────────────────────────────────┤ │ 主协程 │ 3 个 pay 协程 │ ├──────────────────┼───────────────────────────────────────────────────┤ │ for range │ moneyChan \u0026lt;- money ✅ 发送成功（主协程在接收） │ │ 收到 3 个 money │ │ │ 等待更多数据…… │ nameChan \u0026lt;- name ❌ 阻塞！无人接收 │ │ │ （卡死在此，defer wait.Done() 永远不会执行） │ │ 因信道未关闭 │ │ │ 无法退出循环 │ │ │ 💀 所有 Goroutine 互等 = 死锁 │ └──────────────────┴───────────────────────────────────────────────────┘ 关键链条：pay 卡在 nameChan 发送 → Done() 不执行 → Wait() 不通过 → close(moneyChan) 不执行 → 主协程的第一个 for range 永远退不出来。\n14.3 解决方案 # 方案 做法 优缺点 带缓冲信道 make(chan int, 3) 最简单，但需预知数据量 select 多路复用 用 select 同时监听多个信道 推荐，见第 19 节 合并为结构体信道 定义 PayRecord 结构体，使用单一信道 工程最佳实践，数据映射不会错乱 15. defer 执行顺序与信道关闭广播机制 # 15.1 defer 的栈结构（后进先出） # 同一函数内的多个 defer 语句使用 栈（Stack） 存储，遵循 LIFO（Last In, First Out） 原则——越靠后注册的 defer，越先被执行：\ngo func() { defer close(moneyChan) // ① 第一个入栈 → 最后执行 defer close(nameChan) // ② 第二个入栈 → 倒数第二执行 defer close(doneChan) // ③ 最后入栈 → 最先执行 wait.Wait() }() 弹栈执行顺序：close(doneChan) → close(nameChan) → close(moneyChan)\n15.2 从已关闭的 Channel 读取数据 # 这是 Go 并发编程中的一个核心特性：\n信道状态 读取行为 返回值 有数据未关闭 正常接收，无数据时阻塞等待 数据值，ok = true 已关闭且仍有缓冲数据 正常接收，不阻塞 缓冲数据值，ok = true 已关闭且无数据 立即返回，不阻塞 对应类型的零值，ok = false 未关闭且无数据 阻塞等待 — 零值示例：int 返回 0，string 返回 \u0026quot;\u0026quot;，bool 返回 false。\n15.3 关闭信道作为广播信号（Broadcast Pattern） # 利用上述特性，close(channel) 可以充当一个一对多的广播信号——所有在 \u0026lt;-channel 上等待的协程都会同时被唤醒：\nvar quit = make(chan struct{}) // 空结构体不占内存，专用于信号传递 // 100 个工作协程都在监听同一个退出信号 for i := 0; i \u0026lt; 100; i++ { go func() { for { select { case \u0026lt;-quit: fmt.Println(\u0026#34;收到退出信号，安全退出\u0026#34;) return default: // 继续执行正常工作…… } } }() } // 主协程只需关闭一次，100 个协程全部同时收到信号 close(quit) 对比：向信道发送值（quit \u0026lt;- struct{}{}）只能唤醒一个接收方；而 close(quit) 能同时唤醒所有接收方。这就是\u0026quot;广播\u0026quot;与\u0026quot;单播\u0026quot;的区别。\n15.4 defer close 顺序的实际影响 # 当多个信道的关闭存在先后依赖时，defer 的注册顺序至关重要：\n// ✅ 正确：doneChan 最先关闭，触发 select 退出，避免读到脏数据 defer close(moneyChan) // 第三个关闭 defer close(nameChan) // 第二个关闭 defer close(doneChan) // 第一个关闭 → 主协程 select 立即退出 // ❌ 危险：moneyChan/nameChan 先关闭，doneChan 最后才关闭 defer close(doneChan) // 第三个关闭（太迟了！） defer close(moneyChan) // 第二个关闭 defer close(nameChan) // 第一个关闭 // 在 doneChan 关闭之前的真空期，select 会疯狂从已关闭的信道中收到零值 16. select 多路复用与并发同步模式对比 # 16.1 select 语句：Channel 的 switch # select 是 Go 专为 Channel 设计的多路复用语句，能同时监听多个信道，哪个先就绪就执行哪个分支：\nfor { select { case money, ok := \u0026lt;-moneyChan: if ok { moneyList = append(moneyList, money) } case name, ok := \u0026lt;-nameChan: if ok { nameList = append(nameList, name) } case \u0026lt;-doneChan: // 收到退出信号，结束循环 return } } 16.2 ok 检测：防御关闭信道的零值污染 # 从 Channel 接收数据时，Go 支持双返回值形式，第二个布尔值 ok 用于判断数据是否有效：\nvalue, ok := \u0026lt;-channel ok 的值 含义 应对策略 true 数据来自正常发送（ch \u0026lt;- data） 正常处理 false 信道已关闭，value 是该类型的零值 丢弃，不处理 在多信道 + select 的场景中，必须使用 ok 检测。因为不同信道的关闭存在微小时间差，在退出信号到达之前，已关闭的信道可能会吐出大量零值污染数据。\n16.3 两种并发同步模式对比 # Go 中协调协程\u0026quot;生死同步\u0026quot;的两大核心机制：\n维度 sync.WaitGroup Channel + close 核心语义 批量等待一组任务完成（Join） 非阻塞式并发信号广播（Broadcast） 等待行为 Wait() 死等阻塞，期间什么都不能做 select 监听，等待期间可同时处理其他事务 信号方向 子 → 主（子协程报到完毕） 双向均可（主→子 或 子→主） 通知范围 无广播能力 close 一次可唤醒所有监听者 数据传递 仅提供同步，不能传递数据 天然支持数据传递 典型场景 等待 N 个爬虫全部完成后汇总结果 主协程通知所有工作协程安全退出 组合使用 常与 Channel 配合 常与 WaitGroup 配合 16.4 完整实战模式：WaitGroup + Channel + select # 以下是本次学习中使用的完整并发最佳实践模式：\nvar moneyChan = make(chan int) var nameChan = make(chan string) var doneChan = make(chan struct{}) // 空结构体信道，专用于退出信号 func pay(name string, money int, wait *sync.WaitGroup) { defer wait.Done() fmt.Printf(\u0026#34;%s 在付钱\\n\u0026#34;, name) time.Sleep(1 * time.Second) moneyChan \u0026lt;- money nameChan \u0026lt;- name } func main() { var wait sync.WaitGroup wait.Add(3) startTime := time.Now() go pay(\u0026#34;张三\u0026#34;, 2, \u0026amp;wait) go pay(\u0026#34;李四\u0026#34;, 3, \u0026amp;wait) go pay(\u0026#34;王五\u0026#34;, 5, \u0026amp;wait) // 后台协程：等待所有任务完成后，按正确顺序关闭信道 go func() { defer close(moneyChan) // 第三个关闭 defer close(nameChan) // 第二个关闭 defer close(doneChan) // 第一个关闭（触发退出信号） wait.Wait() }() var moneyList []int var nameList []string // select 多路复用：同时监听数据信道和退出信号 for { select { case money, ok := \u0026lt;-moneyChan: if ok { moneyList = append(moneyList, money) } case name, ok := \u0026lt;-nameChan: if ok { nameList = append(nameList, name) } case \u0026lt;-doneChan: fmt.Println(\u0026#34;总耗时:\u0026#34;, time.Since(startTime)) fmt.Println(moneyList) fmt.Println(nameList) return } } } Go 并发哲学：Do not communicate by sharing memory; instead, share memory by communicating.（不要通过共享内存来通信，而应该通过通信来共享内存。）以上模式正是这一哲学的典型体现——WaitGroup 保障任务全部完成，Channel 传递数据，close 广播退出信号，select 实现非阻塞多路监听。\n17. 函数闭包（Closure） # 17.1 定义：函数 + 它捕获的外部变量 # Go 中的闭包（Closure）本质上是：一个函数在创建时，顺便捕获了它所依赖的外部变量，因此以后单独调用这个函数时，仍然可以继续使用这些变量。\n换句话说：\n闭包不只是\u0026quot;一段可执行逻辑\u0026quot; 它还是\u0026quot;一个自带上下文或状态的函数实例\u0026quot; 典型形式：外层函数接收前置条件或创建局部变量，内层匿名函数引用这些变量并被返回出去。\nfunc counter() func() int { n := 0 return func() int { n++ return n } } 上例中，返回的匿名函数就是闭包，因为它捕获了外层函数中的 n。\n17.2 为什么函数调用结束后，闭包还能记住变量？ # 很多初学者容易困惑：\u0026ldquo;函数不是调用结束后，局部变量就应该消失吗？\u0026rdquo;\n关键点在于：只要外部仍然有东西继续使用这个变量，它就不能被清掉。\n以 counter 为例：\n调用 counter() 时，创建局部变量 n 返回的匿名函数引用了 n 因此外部虽然拿到的是\u0026quot;函数\u0026quot;，但这个函数内部仍然持有 n 只要这个闭包还活着，n 就必须继续存在 所以真正\u0026quot;记住状态\u0026quot;的，不是外层函数本身，而是返回出去的闭包绑定并持有了对应的外部变量。\n对比例子：\nfunc add() int { n := 0 n++ return n } add() 每次调用都会重新创建一个新的 n，因此结果总是从头开始；而闭包会复用同一个被捕获的 n。\n17.3 闭包与循环的本质区别 # 很多场景表面上看像是\u0026quot;计数\u0026quot;，似乎用 for 循环也能完成；但循环和闭包解决的并不是同一类问题。\n机制 核心作用 更适合的场景 for 循环 立即重复执行一段逻辑 已知次数的连续处理 闭包 让函数携带状态或上下文，并在未来继续使用 事件驱动、延迟调用、生成多个独立实例 例如：\ncount := 0 for i := 0; i \u0026lt; 3; i++ { count++ fmt.Println(count) } 这段代码当然能完成\u0026quot;计数\u0026quot;，但它只适合在当前这段流程里连续执行三次。一旦流程结束，计数逻辑也随之结束。\n而闭包生成的是一个\u0026quot;以后想调用就能继续调用\u0026quot;的函数实例：\nc1 := counter() c2 := counter() fmt.Println(c1()) // 1 fmt.Println(c1()) // 2 fmt.Println(c2()) // 1 这里：\nc1 和 c2 是两个独立实例 它们各自记住自己的状态 互不影响 所以可以记住一句话：\nfor 循环解决的是：重复做 闭包解决的是：带着状态做 17.4 闭包最常见的工程价值 # 闭包在 Go 中通常不是为了炫技，而是为了将\u0026quot;逻辑\u0026quot;与\u0026quot;外部上下文\u0026quot;绑定在一起，形成一个可反复调用的轻量实例。\n常见用途包括：\n计数器 / 命中次数统计\n例如记录某个函数被调用了多少次 适合封装小型状态，而不必专门定义结构体 参数预绑定 / 带前缀日志函数\n提前固定 prefix、tag 等参数 之后只需要传入真正变化的部分 func prefixLogger(prefix string) func(string) { return func(msg string) { fmt.Println(prefix, msg) } } 中间件配置注入\nWeb 开发中很常见 外层函数接收 role、timeout、path 等配置 返回真正处理请求的函数 回调、协程、延迟执行\n将当前上下文带入 goroutine、定时器、事件回调中 测试辅助代码\n生成带固定输入的断言函数、mock 行为、测试桩逻辑 17.5 闭包保存的不一定是\u0026quot;变化的状态\u0026quot; # 对闭包的理解不要只停留在\u0026quot;它会记住递增的数字\u0026quot;这一层。\n闭包捕获的外部变量，通常分为两类：\n类型 例子 特点 固定配置 prefix、role、timeout 创建后基本不变，属于上下文信息 可变状态 count、retryTimes、命中次数 会随着多次调用不断变化 因此，更准确地说：\n闭包的作用是把函数执行所需的外部上下文封装起来。这个上下文既可能是固定配置，也可能是可变状态。\n17.6 对闭包的一个实用理解模板 # 可以把闭包理解成：\n创建了一个函数实例，这个实例不仅有逻辑，还自带一份独立的上下文或状态。\n这也是为什么下面这些模式本质上都很像：\n计数器 带前缀日志函数 记住配置的中间件 带固定参数的回调函数 它们的共同点是：\n外层函数先接收前置条件或创建局部变量 返回一个内部函数 内部函数捕获这些变量 以后无论在什么地方调用它，都能继续使用当初那份上下文 17.7 常见陷阱：for 循环中的变量捕获 # 闭包捕获的是变量本身，不一定是你直觉里理解的\u0026quot;当时那个值\u0026quot;。因此在 for 循环中配合 goroutine 使用时，容易踩坑。\n推荐写法：\nfor i := 0; i \u0026lt; 3; i++ { go func(v int) { fmt.Println(v) }(i) } 这里通过参数 v 显式传值，避免多个闭包共享同一个循环变量带来的混乱。\n17.8 一句话总结 # 闭包 = 一个会记住外部上下文的函数。\n当你需要：\n预先绑定参数 封装小型状态 创建多个彼此独立的函数实例 让函数在未来调用时继续使用当前上下文 闭包通常就是一种非常自然、非常轻量的实现方式。\n18. 类型断言与 Comma-ok 惯用法 # 18.1 什么是类型断言（Type Assertion） # 在 Go 中，空接口 interface{}（Go 1.18+ 可写作 any）可以装任何类型的值。但一旦放进去，编译器就丢失了它的具体类型信息，无法直接对它做加减乘除或调用具体方法。\n类型断言就是从空接口中把值\u0026quot;拆\u0026quot;出来、还原为原本具体类型的操作。语法固定为：\n具体值 := 接口变量.(目标类型) 一眼识别法：只要看到代码里有 变量.(类型) 这种带点号和括号的写法，就是类型断言。\n18.2 两种断言方式 # 方式一：直接断言（不安全，猜错就崩溃） # var box interface{} = \u0026#34;hello\u0026#34; str := box.(string) // ✅ 猜对了，str = \u0026#34;hello\u0026#34; num := box.(int) // ❌ 猜错了，程序直接 Panic 崩溃！ 适用于上下文已经 100% 确保类型正确的场景（例如外层有 switch 判断兜底）。\n方式二：Comma-ok 断言（安全，推荐日常使用） # var box interface{} = 666 if str, ok := box.(string); ok { fmt.Println(\u0026#34;猜对了！值是:\u0026#34;, str) } else { fmt.Println(\u0026#34;猜错了，它不是字符串\u0026#34;) // 程序不会崩溃，走到这里 } 情况 ok 的值 str 的值 程序行为 类型正确 true 实际的值 正常继续 类型错误 false 该类型的零值（如 \u0026quot;\u0026quot;） 正常继续 18.3 类型断言与反射的配合 # 在反射场景中，类型断言可以用来从 interface{} 直接提取原生值，作为 reflect.Value 方法的替代方案：\nswitch v1.Elem().Kind() { case reflect.String: // 方式 A：通过类型断言从 interface{} 直接提取原生 string v1.Elem().SetString(value.(string)) case reflect.Int: // 方式 B：通过反射对象的 .Int() 方法提取原生 int64 v1.Elem().SetInt(v2.Int()) } 两种方式效果等价——SetString(value.(string)) 和 SetString(v2.String()) 都能正常工作。区别仅在于值的来源：一个直接从 interface{} 断言，另一个从 reflect.Value 反射对象中提取。\n18.4 Comma-ok 惯用法：Go 的核心设计哲学 # Go 语言中，凡是**\u0026ldquo;可能成功、也可能失败/不存在\u0026rdquo;**的操作，都倾向于返回两个值，让调用方自行决定如何处理。这种模式被称为 Comma-ok 惯用法。\n类型断言只是这种设计模式下的一个具体应用场景。\nGo 语法层面内置的三种 Comma-ok（返回 值, bool） # 场景 语法 ok = false 的含义 类型断言 v, ok := x.(string) 接口里装的不是 string 字典查询 v, ok := myMap[\u0026quot;key\u0026quot;] 字典中不存在该键 信道接收 v, ok := \u0026lt;-ch 信道已被关闭 这三种是 Go 编译器级别特殊支持的——允许你选择用 1 个值接收（出错就崩溃/取零值）或 2 个值接收（自己处理）。\n函数级别的状态返回（返回 值, error） # 除了上面三种内置语法，Go 的普通函数也大量采用类似的设计，只是把 bool 升级为了具体的错误信息 error：\nfile, err := os.Open(\u0026#34;test.txt\u0026#34;) // 文件操作 resp, err := http.Get(\u0026#34;https://...\u0026#34;) // 网络请求 rows, err := db.Query(\u0026#34;SELECT ...\u0026#34;) // 数据库查询 这些不是类型断言，但它们与类型断言共享同一种设计哲学：把正常结果和状态信息分开返回，让调用方显式处理失败情况。\n18.5 如何一眼区分\u0026quot;是不是类型断言\u0026quot; # 永远只看语法长相，不要看有几个返回值：\n代码 是否为类型断言 判断依据 v := x.(string) ✅ 是 有 .(类型) 语法 v, ok := x.(string) ✅ 是 有 .(类型) 语法 v, ok := myMap[\u0026quot;key\u0026quot;] ❌ 不是 是 Map 取值，没有 .(类型) v, ok := \u0026lt;-ch ❌ 不是 是 Channel 接收，没有 .(类型) file, err := os.Open(\u0026quot;x\u0026quot;) ❌ 不是 是普通函数调用，没有 .(类型) 18.6 类型断言的常见工程应用 # 场景一：处理 JSON 动态数据 # 前端发来结构不确定的 JSON，解析到 map[string]interface{} 后，必须用类型断言提取具体值：\nvar data map[string]interface{} json.Unmarshal(body, \u0026amp;data) if name, ok := data[\u0026#34;name\u0026#34;].(string); ok { fmt.Println(\u0026#34;用户名:\u0026#34;, name) } if age, ok := data[\u0026#34;age\u0026#34;].(float64); ok { // JSON 数字默认解析为 float64 fmt.Println(\u0026#34;年龄:\u0026#34;, int(age)) } 场景二：错误类型判断 # 判断一个 error 接口具体是哪种错误，以便做针对性处理：\nif netErr, ok := err.(net.Error); ok { if netErr.Timeout() { fmt.Println(\u0026#34;网络超时，正在重试……\u0026#34;) } } Go 1.13+ 推荐使用 errors.As() 替代直接断言，但底层原理相同。\n场景三：配合 switch 进行多类型分发 # 当一个 interface{} 可能是多种类型时，使用 type switch 语法逐一匹配：\nfunc describe(value interface{}) { switch v := value.(type) { // 特殊语法：.(type) 只能在 switch 中使用 case string: fmt.Println(\u0026#34;字符串，长度:\u0026#34;, len(v)) case int: fmt.Println(\u0026#34;整数，值:\u0026#34;, v) case bool: fmt.Println(\u0026#34;布尔值:\u0026#34;, v) default: fmt.Println(\u0026#34;未知类型\u0026#34;) } } value.(type) 是类型断言的变体语法，只能出现在 switch 语句中，用于同时判断类型并提取值。\n18.7 总结 # 概念 含义 类型断言 从 interface{} 中提取具体类型值的特定语法操作（x.(Type)） Comma-ok Go 语言的设计哲学/编码风格，用双返回值让调用方自行处理失败情况 两者的关系 类型断言的安全写法应用了 Comma-ok 模式，但 Comma-ok 的范围远大于类型断言 口诀：看到 .(类型) 就是类型断言；看到 值, ok/err 就是 Comma-ok 思想。前者是后者的一个子集。\n","date":"2026年6月9日","externalUrl":null,"permalink":"/posts/go-basic/getting-started-with-golang/","section":"文章","summary":"","title":"初识 GoLang","type":"posts"},{"content":" 1. LLaMA: Open and Efficient Foundation Language Models # 1.1 前言 # ​\tLLaMA是一个系列模型，模型参数量从7B到65B。在大部分的任务上，LLaMA-13B强于GPT-3(175B)。LLaMA-65B的性能，可以和最好的LM相媲美，如Chinchilla-70B 和 PaLM-540B。\n​\t一般而言，模型越大，效果越好。如以 GPT-3 为代表的大语言模型在海量文本集合上训练，展示出了惊人的涌现能力以及零样本迁移和少样本学习能力。GPT-3 把模型的量级缩放到了 175B，也使得后面的研究工作继续去放大语言模型的量级。大家好像有一个共识，就是：模型参数量级的增加就会带来同样的性能提升。\n​\t最近的 \u0026ldquo;Training Compute-Optimal Large Language Models\u0026rdquo; 这篇论文提出一种缩放定律 (Scaling Law)：\n训练大语言模型时，在计算成本达到最优情况下，模型大小和训练数据 (token) 的数量应该比例相等地缩放，即：如果模型的大小加倍，那么训练数据的数量也应该加倍。\n​\t即当我们给定特定的计算成本预算的前提下，语言模型的最佳性能不仅仅可以通过设计较大的模型搭配小一点的数据集得到，也可以通过设计较小的模型配合大量的数据集得到。\n​\t那么，相似成本训练 LLM，是大 LLM 配小数据训练，还是小 LLM 配大数据训练更好？\n​\t**缩放定律 ** 告诉我们对于给定的特定的计算成本预算，如何去匹配最优的模型和数据的大小。但是本文作者团队认为，这个功能只考虑了总体的计算成本，忽略了推理时候的成本。因为大部分社区用户其实没有训练 LLM 的资源，他们更多的是拿着训好的 LLM 来推理。在这种情况下，我们首选的模型应该不是训练最快的，而应该是推理最快的 LLM。呼应上题，本文认为答案就是：小 LLM 配大数据训练更好，因为小 LLM 推理更友好。\n​\tLLaMa 沿着小 LLM 配大数据训练的指导思想，训练了一系列性能强悍的语言模型，参数量从 7B 到 65B。例如，LLaMA-13B 比 GPT-3 小10倍，但是在大多数基准测试中都优于 GPT-3。大一点的 65B 的 LLaMa 模型也和 Chinchilla 或者 PaLM-540B 的性能相当。\n​\t同时，LLaMa 模型只使用了公开数据集，开源之后可以复现。但是大多数现有的模型都依赖于不公开或未记录的数据完成训练。\n1.2 预训练数据 # 1.2.1 数据集 # ​\tLLaMa 预训练数据大约包含 1.4T tokens，对于绝大部分的训练数据，在训练期间模型只见到过1次，Wikipedia 和 Books 这两个数据集见过2次。\n​\t如下图所示是 LLaMa 预训练数据的含量和分布，其中包含了 CommonCrawl 和 Books 等不同域的数据。\n​\tCommonCrawl (占 67%)： 包含 2017 到 2020 的5个版本，预处理部分包含：删除重复数据，去除掉非英文的数据，并通过一个 n-gram 语言模型过滤掉低质量内容。\n​\tC4 (Colossal Clean Crawled Corpus 占 15%)： 在探索性实验中，作者观察到使用不同的预处理 CommonCrawl 数据集可以提高性能，因此在预训练数据集中加了 C4。预处理部分包含：删除重复数据，过滤的方法有一些不同，主要依赖于启发式方法，例如标点符号的存在或网页中的单词和句子的数量。\n​\tGithub (占 4.5%)： 在 Github 中，作者只保留在 Apache、BSD 和 MIT 许可下的项目。此外，作者使用基于行长或字母数字字符比例的启发式方法过滤低质量文件，并使用正则表达式删除标题。最后使用重复数据删除。\n​\tWikipedia (占 4.5%)： 作者添加了 2022 年 6-8 月的 Wikipedia 数据集，包括 20 种语言，作者处理数据以删除超链接、评论和其他格式样板。\n​\tGutenberg and Books3 (占 4.5%)： 作者添加了两个书的数据集，分别是 Gutenberg 以及 ThePile (训练 LLM 的常用公开数据集) 中的 Book3 部分。处理数据时作者执行重复数据删除，删除内容重叠超过 90% 的书籍。\n​\tArXiv (占 2.5%)： 为了添加一些科学数据集，作者处理了 arXiv Latex 文件。作者删除了第一部分之前的所有内容，以及参考文献。还删除了 .tex 文件的评论，以及用户编写的内联扩展定义和宏，以增加论文之间的一致性。\n​\tStack Exchange (占 2%)： 作者添加了 Stack Exchange，这是一个涵盖各种领域的高质量问题和答案网站，范围从计算机科学到化学。作者从 28 个最大的网站保留数据，从文本中删除 HTML 标签并按分数对答案进行排序。\n1.2.2 Tokenzier # ​\t使用byte pair encoding (BPE) 算法，使用的是Sentence-Piece的实现。所有数字被拆分为单独的digit，所有未知的UTF-8 字符，回退到字节来进行分解。因此，LLaMA 可以通过byte 的方式，构造出很多不在 vocab 中的字符，从而也具有较好的多语言能力。\nimport os # 导入 os 模块，用于操作系统功能，如文件路径 from logging import getLogger # 从 logging 模块导入 getLogger，用于日志记录 from typing import List # 从 typing 模块导入 List，用于指定列表类型的注释 from sentencepiece import SentencePieceProcessor # 从 sentencepiece 导入 SentencePieceProcessor，用于文本分词和编码/解码 logger = getLogger() # 获取一个日志记录器对象 class Tokenizer: \u0026#34;\u0026#34;\u0026#34;使用 SentencePiece 进行文本的分词和编码/解码。\u0026#34;\u0026#34;\u0026#34; def __init__(self, model_path: str): \u0026#34;\u0026#34;\u0026#34; 使用 SentencePiece 模型初始化 Tokenizer。 参数: model_path (str): SentencePiece 模型文件的路径。 \u0026#34;\u0026#34;\u0026#34; # 重新加载分词器 assert os.path.isfile(model_path), model_path # 断言模型路径是一个文件，如果不是则抛出异常 self.sp_model = SentencePieceProcessor(model_file=model_path) # 加载 SentencePiece 模型 logger.info(f\u0026#34;Reloaded SentencePiece model from {model_path}\u0026#34;) # 记录日志，表示模型已重新加载 # 设置开始（BOS）/结束（EOS）标记的 ID self.n_words: int = self.sp_model.vocab_size() # 词汇表大小 self.bos_id: int = self.sp_model.bos_id() # 开始标记的 ID self.eos_id: int = self.sp_model.eos_id() # 结束标记的 ID self.pad_id: int = self.sp_model.pad_id() # 填充（PAD）标记的 ID logger.info( f\u0026#34;#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}\u0026#34; ) # 记录词汇表大小和各标记的 ID assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() # 确保词汇表大小与分词器的大小一致 def encode(self, s: str, bos: bool, eos: bool) -\u0026gt; List[int]: \u0026#34;\u0026#34;\u0026#34; 将字符串编码成一个 token ID 列表。 参数: s (str): 要编码的输入字符串。 bos (bool): 是否在序列开始处添加开始标记。 eos (bool): 是否在序列结束处添加结束标记。 返回: List[int]: token ID 的列表。 \u0026#34;\u0026#34;\u0026#34; assert type(s) is str # 断言输入是字符串类型 t = self.sp_model.encode(s) # 使用 SentencePiece 模型编码字符串 if bos: t = [self.bos_id] + t # 如果需要，添加开始标记 if eos: t = t + [self.eos_id] # 如果需要，添加结束标记 return t def decode(self, t: List[int]) -\u0026gt; str: \u0026#34;\u0026#34;\u0026#34; 将 token ID 列表解码成字符串。 参数: t (List[int]): 要解码的 token ID 列表。 返回: str: 解码后的字符串。 \u0026#34;\u0026#34;\u0026#34; return self.sp_model.decode(t) # 使用 SentencePiece 模型解码 token ID 列表 1.2.3 BPE算法 # ​\t**NLP中分词的概念如下：**执行分词的算法模型称为分词器（Tokenizer） ，划分好的一个个词称为 Token （为啥不直接叫 Word？接着往后看），这个过程称为 Tokenization 。\n​\t我们将一个个的 token（可以理解为小片段）表示向量，我们分词的目的就是尽可能的让这些向量蕴含更多有用的信息，然后把这些向量输入到算法模型中。\n​\t由于一篇文本的词往往太多了，为了方便算法模型训练，我们会选取出频率 （也可能是其它的权重）最高的若干个词组成一个词表（Vocabulary） 。\n​\t我们知道，一门语言中，通常有几万到几十万量级的单词数。若使用这种编码方式（one-hot），在语言模型预测的时候需要在这个拥有几万个单词的列表上计算一个概率分布，那样的计算量是非常恐怖的，而且过大的token列表十分影响模型的预测准确度。在GPT-3提出以后，又增加了prompt的feature，其特点之一就是用户可以指定将源语言翻译成某一种语言。举个例子，若是我们输入：\nEnglish: Let\u0026rsquo;s have a drink tonight.\nFrench:\n模型就能输出一句与\u0026quot;Let\u0026rsquo;s have a drink tonight.\u0026ldquo;所对应的法语翻译。要是\u0026quot;French：\u0026ldquo;改成\u0026quot;Spanish：\u0026quot;，那模型将输出对应的西班牙语翻译。\n​\t随着模型集成的不同国家的语言越来越多，模型的词汇列表势必会增长到一个非常可怕的数量级，到时候该如何去处理它带来的矩阵内存占用和预测准确性问题呢？\n​\t别急，有一种编码方式能大大减小token list，那就是即将介绍的Byte Pair Encoding(BPE)\n​\tBPE 最早由 Philip Gage 在 1994 年提出，用于数据压缩领域。其核心思想是通过迭代合并频率最高的字节对（byte pair），将原始数据压缩为更紧凑的表示。2015 年，Sennrich 等人将 BPE 引入 NLP，用于神经机器翻译（Neural Machine Translation, NMT），并将其适配为一种子词级别（subword-level）的分词方法。\n​\tBPE 的基本思想可以用一句话概括：从字符级别开始，通过统计频率最高的字符对或子词对，逐步构建一个词汇表，用于表示文本中的单词或子词单元。 这种方法既能保留词的语义信息，又能灵活处理未见过的新词，在深度学习模型中表现出色。\n​\tBPE 的工作原理与实现步骤\n​\tBPE 的实现分为两个主要阶段：训练阶段（构建词汇表）和应用阶段（分词）。以下是详细步骤：\n1. 训练阶段：构建词汇表 **初始化：**输入一个大规模的语料库（corpus），例如一堆句子。对每个单词进行预分词，通常以字符为单位，并在每个单词末尾添加一个特殊标记（如 ），以区分词内字符和词间边界。例如，单词 “cat” 被初始化为 c a t 。 统计语料库中所有单词的初始表示及其出现频率。例如：\n\u0026#34;low\u0026#34;: l o w \u0026lt;/w\u0026gt;, 5次 \u0026#34;lower\u0026#34;: l o w e r \u0026lt;/w\u0026gt;, 3次 \u0026#34;new\u0026#34;: n e w \u0026lt;/w\u0026gt;, 4次 统计字符对频率：\n​\t遍历语料库，统计所有相邻字符对（或子词对）的出现频率。例如，在上面的例子中，可能会统计到：\nl o: 8次（5次来自 \u0026#34;low\u0026#34;，3次来自 \u0026#34;lower\u0026#34;） o w: 8次（5次来自 \u0026#34;low\u0026#34;，3次来自 \u0026#34;lower\u0026#34;） w \u0026lt;/w\u0026gt;: 9次（5次来自 \u0026#34;low\u0026#34;，4次来自 \u0026#34;new\u0026#34;） 合并频率最高的字符对：\n​\t选择频率最高的字符对进行合并。例如，假设 l o 是频率最高的对，则将其合并为 lo，更新语料库中的表示：\n\u0026#34;low\u0026#34;: lo w \u0026lt;/w\u0026gt;, 5次 \u0026#34;lower\u0026#34;: lo w e r \u0026lt;/w\u0026gt;, 3次 \u0026#34;new\u0026#34;: n e w \u0026lt;/w\u0026gt;, 4次 迭代执行：\n​\t重复步骤 2 和 3，合并频率最高的字符对，直到达到预定的词汇表大小（vocabulary size，例如 10,000）或迭代次数上限。每次合并都会生成新的子词单元。例如，下一次可能合并 lo w 为 low，最终词汇表可能包含：\n[l, o, w, e, r, n, \u0026lt;/w\u0026gt;, lo, low, new, ...] 输出词汇表：\n​\t训练完成后，得到一个包含字符和子词的词汇表，用于后续的分词。 2. 应用阶段：分词 ​\t在应用阶段，BPE 使用训练好的词汇表将新输入的文本进行分词。具体步骤如下：\n**单词拆分为字符：**对于输入单词（如 “lowest”），先将其拆分为字符序列并添加词尾标记：l o w e s t 。 贪心合并：\n​\t根据训练阶段生成的词汇表，依次尝试合并字符对，优先选择词汇表中最长的子词单元。例如： ​\t检查 l o，发现 lo 在词汇表中，合并为 lo w e s t 。 ​\t检查 lo w，发现 low 在词汇表中，合并为 low e s t 。 ​\t检查 e s，不在词汇表中，继续检查 e s t，不在词汇表中，最终结果可能是 low e s t 。 ​\t输出子词序列：\n​\t最终输出分词结果：[low, e, s, t]，作为模型的输入 token。\n​\t为了以最有效的方式构建语料库，BPE 在迭代的时候通过比较token的频率大小来穷尽每一种可能。所以，是的，它遵循一种贪婪的策略来尽可能取得最优的解决方案。\n​\t无论如何，BPE 是使用最广泛的sub-word tokenization算法之一。尽管贪婪，但它具有良好的性能！并被作为机器翻译等主流NLP任务的首选tokenize方法之一。\n1.3 网络结构改进 # ​\t在LLaMa中使用了基于transformer的架构，并做了如下3点改进：\n1.3.1 Pre-Normalization # ​\t为了提高训练稳定性，LLaMa 对每个 Transformer 的子层的输入进行归一化，而不是对输出进行归一化。同时，使用 RMSNorm归一化函数。\n​\tRoot Mean Square Normalization，是一种归一化技术，用于深度神经网络中，特别是在处理序列数据时（如在自然语言处理任务中）。这种技术的目的是通过调整网络层的输入来改善训练过程的稳定性和效率。\n常规的 Layer Normalization： aiˉ=ai−μσgi, yi=f(aiˉ+bi) \\bar{a_i}=\\frac{a_i-\\mu}{\\sigma}g_i,\\space\\space y_i=f(\\bar{a_i}+b_i) ai​ˉ​=σai​−μ​gi​, yi​=f(ai​ˉ​+bi​) ​\t式中，gig_igi​ 和 是 bib_ibi​ LN 的缩放系数和平移项，这两个参数都是可训练的，训练过程中网络会自动学习它们，以便在标准化之后仍能保留或恢复必要的信息。，μ\\muμ 和 σ\\sigmaσ 的计算如下式所示： μ=1n∑i=1nai, σ=1n∑i=1n(ai−μ)2 \\mu=\\frac{1}{n}\\sum_{i=1}^na_i,\\space \\sigma=\\sqrt{\\frac{1}{n}\\sum_{i=1}^n(a_i-\\mu)^2} μ=n1​i=1∑n​ai​, σ=n1​i=1∑n​(ai​−μ)2​ **RMSNorm：**相当于是去掉了 μ\\muμ 这一项。 aiˉ=aiRMS(a)gi, where RMS(a)=1n∑i=1nai2 \\bar{a_i}=\\frac{a_i}{\\text{RMS(a)}}g_i,\\space \\text{where} \\space \\text{RMS(a)}=\\sqrt{\\frac{1}{n}\\sum_{i=1}^na_i^2} ai​ˉ​=RMS(a)ai​​gi​, where RMS(a)=n1​i=1∑n​ai2​​看上去就这一点小小的改动，有什么作用呢？RMSNorm 的原始论文进行了一些不变性的分析和梯度上的分析。\n和 LayerNorm 对比：\n计算复杂性：RMSNorm 的计算通常比 LayerNorm 简单，因为它不涉及计算均值，这可以在减少约 7%∼64% 的计算时间。 对于序列长度的适应性：两者都适用于处理序列数据，尤其是在自然语言处理中。但 RMSNorm 在处理非常长的序列时可能表现更好，因为它不依赖于均值的计算。 稳定性和效率：两者都旨在提高训练过程的稳定性和效率，但它们在处理不同类型的数据集和网络结构时的表现可能有所不同 具体代码如下：\nimport torch import torch.nn as nn class RMSNorm(torch.nn.Module): # 定义一个继承自 torch.nn.Module 的 RMSNorm 类 def __init__(self, dim: int, eps: float = 1e-6): \u0026#34;\u0026#34;\u0026#34; 初始化 RMSNorm 归一化层。 参数: dim (int): 输入张量的维度。 eps (float, optional): 为了数值稳定性添加到分母的小值，默认为 1e-6。 属性: eps (float): 为了数值稳定性添加到分母的小值。 weight (nn.Parameter): 可学习的缩放参数。 \u0026#34;\u0026#34;\u0026#34; super().__init__() # 调用父类的初始化方法 self.eps = eps # 存储数值稳定性参数 self.weight = nn.Parameter(torch.ones(dim)) # 初始化可学习的缩放参数 def _norm(self, x): \u0026#34;\u0026#34;\u0026#34; 对输入张量应用 RMSNorm 归一化。 参数: x (torch.Tensor): 输入张量。 返回: torch.Tensor: 归一化后的张量。 \u0026#34;\u0026#34;\u0026#34; # 计算归一化值，使用均方根（RMS）作为分母 return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): \u0026#34;\u0026#34;\u0026#34; RMSNorm 层的前向传播。 参数: x (torch.Tensor): 输入张量。 返回: torch.Tensor: 应用 RMSNorm 后的输出张量。 \u0026#34;\u0026#34;\u0026#34; # 将输入张量转换为 float 类型，应用归一化，然后恢复原始类型 output = self._norm(x.float()).type_as(x) # 应用可学习的缩放参数并返回结果 return output * self.weight 1.3.2 SwiGLU # ​\t具体各种激活函数的详细介绍请看 LLM 通用基础\n​\tLLaMa 使用 SwiGLU 激活函数替换 ReLU 非线性以提高性能，维度从 4d4d4d 变为 234d\\frac{2}{3}4d32​4d，具体推导如下：\n​\t原来的FFN有两个MLP层，这两个MLP层的参数量分别为：h×4hh \\times 4hh×4h 和 4h×h4h \\times h4h×h，总的参数量为 8h28h^28h2。\n​\tSwiGLU 的公式为： FFNSwiGLU(x,W,V)=[Swish(xW)⊗(xV)]W2 \\text{FFN}_\\text{SwiGLU}(x,W,V)=[\\text{Swish}(xW) \\otimes (xV)]W_2 FFNSwiGLU​(x,W,V)=[Swish(xW)⊗(xV)]W2​ ​\t从上述公式中可以知道，矩阵 WWW 与矩阵 VVV 的维度是相同的，其作用是对输入向量 xxx 进行升维；矩阵 W2W_2W2​ 的作用是将高维的隐向量还原到和输入向量 x 相同的维度。所以 W、V、W2W、V、W_2W、V、W2​ 这三个矩阵的维度分别为：(h,αh)、(h,αh)、(αh,h)(h,\\alpha h)、(h,\\alpha h)、(\\alpha h,h)(h,αh)、(h,αh)、(αh,h)，总的参数量为 3αh23\\alpha h^23αh2。为了保持和原始的 FFN 参数量相同，有：8h2=3αh28h^2=3\\alpha h^28h2=3αh2\n​\t解得 α=83\\alpha =\\frac{8}{3}α=38​，最终 W、V、W2W、V、W_2W、V、W2​ 这三个矩阵的维度分别为：(h,83h)、(h,83h)、(83h,h)(h,\\frac{8}{3}h)、(h,\\frac{8}{3}h)、(\\frac{8}{3}h,h)(h,38​h)、(h,38​h)、(38​h,h)，可以很明显的看出严格按照该公式计算出来的不是整数，所以使用该公式计算出来的是模型真实维度的近似值。\n​\t各个激活函数图像如下所示：\nLLaMa中FFN的代码如下所示：\nimport torch import torch.nn as nn import torch.nn.functional as F class FeedForward(nn.Module): def __init__( self, dim: int, # 输入维度（等于 Transformer 中的 hidden_size） hidden_dim: int, # FFN 隐藏层维度（原始 Transformer 通常为 4 × dim） multiple_of: int, # 用于将隐藏层维度对齐到指定倍数（如 64 的倍数，以加速 GPU 计算） ): super().__init__() # ============================================================ # Step 1: 调整隐藏层维度（LLaMA 特有） # 原始 Transformer 使用 hidden_dim = 4 × dim # LLaMA 使用 SwiGLU 激活，因此将维度缩小为原来的 2/3， # 即 hidden_dim = (2/3) × (4 × dim) ≈ 2.67 × dim # 这样可以在保持性能的同时减少参数量和计算量。 # ============================================================ hidden_dim = int(2 * hidden_dim / 3) # ============================================================ # Step 2: 对齐 hidden_dim 至 multiple_of 的倍数 # 例如 multiple_of = 64 时，将 hidden_dim 向上取整为 64 的倍数， # 便于 GPU 矩阵乘法在 CUDA Tensor Core 上高效执行。 # ============================================================ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) # ============================================================ # Step 3: 定义前馈网络的线性层 # LLaMA 的 FFN 采用 SwiGLU 结构： # FFN(x) = W2( SiLU(W1(x)) * W3(x) ) # # - W1 和 W3 都是升维层（输入维度 dim → hidden_dim） # - W2 是降维层（hidden_dim → dim） # - “*” 表示逐元素相乘（门控机制） # ============================================================ # W1：输入升维，用于生成主特征分支 A self.w1 = ColumnParallelLinear( dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x ) # W2：输出降维，将门控结果投影回原维度 self.w2 = RowParallelLinear( hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x ) # W3：输入升维，用于生成门控分支 B self.w3 = ColumnParallelLinear( dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x ) def forward(self, x): # ============================================================ # 前向传播过程： # 1. w1(x)：主分支特征 A # 2. w3(x)：门控分支特征 B # 3. F.silu(w1(x))：对主分支使用 SiLU 激活函数（Swish） # 4. F.silu(w1(x)) * w3(x)：门控机制，控制信息通过比例 # 5. w2(...)：将结果映射回原维度 # # 最终公式： # FFN(x) = W2( SiLU(W1(x)) * W3(x) ) # ============================================================ return self.w2(F.silu(self.w1(x)) * self.w3(x)) 这里需要注意的点是： 激活函数用的是 F.silu()，也就是 Swish 激活函数。 self.w2(F.silu(self.w1(x)) * self.w3(x)) 的实现也就是 SwiGLU 激活函数\n1.3.3 RoPE # 绝对位置编码：\n在经典 Transformer 中，每个词的 embedding x∈Rdx \\in \\mathbb{R}^{d}x∈Rd 会加上一个位置向量 PE(pos)PE(pos)PE(pos)：z=x+PE(pos)z = x + PE(pos)z=x+PE(pos)\n而 PE(pos)PE(pos)PE(pos) 的每个维度的值由固定函数生成： PE(pos,2i)=sin⁡(pos/100002i/d) PE_{(pos, 2i)} = \\sin(pos / 10000^{2i/d}) PE(pos,2i)​=sin(pos/100002i/d)PE(pos,2i+1)=cos⁡(pos/100002i/d) PE_{(pos, 2i+1)} = \\cos(pos / 10000^{2i/d}) PE(pos,2i+1)​=cos(pos/100002i/d)即：\n每个维度 i 都有一个对应的频率； 模型靠加法 x+PEx + PEx+PE 把位置信息“叠加”进 embedding 向量。 旋转位置编码：\nRoPE（Su et al., 2021, RoFormer）采取了完全不同的策略：\n不再“相加”位置信息，而是通过旋转变换把位置信息直接嵌入到向量的几何空间结构中。\n直观理解：\n对于每个位置 pos，我们把词向量的部分维度成对看作平面坐标，然后对它们旋转一个角度。\n​\t不同于原始 Transformers 论文中，将 pos embedding 和 token embedding 进行相加，RoPE 是将位置编码和 query （或者 key） 进行相乘。具体如下：\n​\t也就是说，给位置为 mmm 的向量 qqq 乘上矩阵 RmR_mRm​、位置为 nnn 的向量 kkk 乘上矩阵 RnR_nRn​，用变换后的 Q,KQ,KQ,K 序列做Attention，那么Attention就自动包含相对位置信息了，因为成立恒等式： (Rmq)T(Rnk)=qTRmTRnk=qTRn−mk (R_mq)^\\text{T}(R_nk)=q^\\text{T}R_m^\\text{T}R_nk=q^\\text{T}R_{n−m}k (Rm​q)T(Rn​k)=qTRmT​Rn​k=qTRn−m​k ​\t值得指出的是，RmR_mRm​是一个正交矩阵，它不会改变向量的模长，因此通常来说它不会改变原模型的稳定性。\n​\t由于 RmR_mRm​ 的稀疏性，所以直接用矩阵乘法来实现会很浪费算力，推荐通过下述方式来实现RoPE：\n​\t其中 ⊗\\otimes⊗ 是逐位对应相乘，θi=1000−2i/d\\theta_i=1000^{-2i/d}θi​=1000−2i/d 。\nRoPE实现代码如下：\n# 代码增加了注释，可以看到和原始公式的对应关系。 class LlamaRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() # 此处 inv_freq 对应公式中的 theta inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer(\u0026#34;inv_freq\u0026#34;, inv_freq) self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) # 此处 freqs 对应公式中的 m * theta, t 对应公式中的 m，表示位置 freqs = torch.einsum(\u0026#34;i,j-\u0026gt;ij\u0026#34;, t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation # 此处和原始公式不同，theta_0 和 theta_0 不再相邻 # 而是分在向量的前半部分和后半部分 emb = torch.cat((freqs, freqs), dim=-1) dtype = torch.get_default_dtype() self.register_buffer(\u0026#34;cos_cached\u0026#34;, emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer(\u0026#34;sin_cached\u0026#34;, emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len \u0026gt; self.max_seq_len_cached: self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) freqs = torch.einsum(\u0026#34;i,j-\u0026gt;ij\u0026#34;, t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.register_buffer(\u0026#34;cos_cached\u0026#34;, emb.cos()[None, None, :, :].to(x.dtype), persistent=False) self.register_buffer(\u0026#34;sin_cached\u0026#34;, emb.sin()[None, None, :, :].to(x.dtype), persistent=False) # 大部分情况下，直接从这里返回 return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) def rotate_half(x): \u0026#34;\u0026#34;\u0026#34;Rotates half the hidden dims of the input.\u0026#34;\u0026#34;\u0026#34; # 此次和原始推导中不同，正负号不是间隔的，而是分前半部分和后半部分。但对于结果没有影响 x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] # 对应上图中 RoPE 的简化计算 q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed 1.4 高效实现 # 1.4.1LLaMa的优化与高效实现 # ​\tAdamW， β1=0.9,β2=0.95\\beta_1=0.9, \\beta_2=0.95β1​=0.9,β2​=0.95，使用 cosine 学习率衰减策略，2000 步的 warm-up，最终学习率等于最大学习率的 10%，使用 0.1 的权重衰减和 1.0 的梯度裁剪。\n​\t**快速的注意力机制：**LLaMa 采用了高效的 causal multi-head attention (基于 xformers)，不存储注意力权重，且不计算 mask 掉的 query 和 key 的值。\n​\t**手动实现反向传播过程，不使用 PyTorch autograd：**使用 checkpointing 技术减少反向传播中的激活值的计算，更准确地说，LLaMa 保存计算代价较高的激活值，例如线性层的输出。\n​\t通过使用模型和序列并行减少模型的内存使用。此外，LLaMa 还尽可能多地重叠激活的计算和网络上的 GPU 之间的通信。\n​\tLLaMa-65B 的模型使用 2048 块 80G 的 A100 GPU，在 1.4T token 的数据集上训练 21 天。\n1.4.2 其他代码实现： # Self-Attention 的 PyTorch 代码：\nclass Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size() self.head_dim = args.dim // args.n_heads self.wq = ColumnParallelLinear( args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x, ) self.wk = ColumnParallelLinear( args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x, ) self.wv = ColumnParallelLinear( args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x, ) self.wo = RowParallelLinear( args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x, ) self.cache_k = torch.zeros( (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) ).cuda() self.cache_v = torch.zeros( (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) ).cuda() def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) self.cache_k = self.cache_k.to(xq) self.cache_v = self.cache_v.to(xq) self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv keys = self.cache_k[:bsz, : start_pos + seqlen] values = self.cache_v[:bsz, : start_pos + seqlen] xq = xq.transpose(1, 2) keys = keys.transpose(1, 2) values = values.transpose(1, 2) scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) if mask is not None: scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) output = output.transpose( 1, 2 ).contiguous().view(bsz, seqlen, -1) return self.wo(output) 这里有几个地方值得注意一下： 首先是 model.py 文件里面从 fairscale 中 import 了3个类，分别是：ParallelEmbedding，RowParallelLinear，和 ColumnParallelLinear。 Fairscale 链接如下，是一个用于高性能大规模预训练的库，LLaMa 使用了其 ParallelEmbedding 去替换 Embedding， 使用了其 RowParallelLinear 和 ColumnParallelLinear 去替换 nn.Linear，猜测可能是为了加速吧。\n另一个需要注意的点是：cache 的缓存机制，可以看到在构造函数里面定义了下面两个东西： self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda() self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)).cuda()\n关键其实就是这几行代码： self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv keys = self.cache_k[:bsz, : start_pos + seqlen] values = self.cache_v[:bsz, : start_pos + seqlen]\n在训练的时候，因为每次都是输入完整的一句话，所以 cache 机制其实是不发挥作用的。 在推理的时候，比如要生成 \u0026ldquo;I have a cat\u0026rdquo;，过程是： 1 输入 \u0026lt;s\u0026gt;，生成 \u0026lt;s\u0026gt; I。 2 输入 \u0026lt;s\u0026gt; I，生成 \u0026lt;s\u0026gt; I have。 3 输入 \u0026lt;s\u0026gt; I have，生成 \u0026lt;s\u0026gt; I have a。 4 输入 \u0026lt;s\u0026gt; I have a，生成 \u0026lt;s\u0026gt; I have a cat。\n在执行3这一步时，计算 \u0026ldquo;a\u0026rdquo; 的信息时，还要计算 \u0026lt;s\u0026gt; I have 的 Attention 信息，比较复杂。因此，cache 的作用就是在执行2这一步时，提前把 \u0026lt;s\u0026gt; I have 的 keys 和 values 算好，并保存在 self.cache_k 和 self.cache_v 中。在执行3这一步时，计算 Attention 所需的 keys 和 values 是直接从这里面取出来的： keys = self.cache_k[:bsz, : start_pos + seqlen] values = self.cache_v[:bsz, : start_pos + seqlen] 只需要额外地计算 \u0026ldquo;a\u0026rdquo; 的 keys 和 values 即可，这对模型的快速推理是至关重要的。\n还有一个值得注意的点：self.cache_k = self.cache_k.to(xq) 这里使用的是 to() 函数的一种不太常见的用法：torch.to(other, non_blocking=False, copy=False)→Tensor Returns a Tensor with same torch.dtype and torch.device as the Tensor other.\nTransformer Block 的 PyTorch 代码：\nclass TransformerBlock(nn.Module): def __init__(self, layer_id: int, args: ModelArgs): super().__init__() self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads self.attention = Attention(args) self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of ) self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask) out = h + self.feed_forward.forward(self.ffn_norm(h)) return out Transformer 的 PyTorch 代码：\nclass Transformer(nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params self.vocab_size = params.vocab_size self.n_layers = params.n_layers self.tok_embeddings = ParallelEmbedding( params.vocab_size, params.dim, init_method=lambda x: x ) self.layers = torch.nn.ModuleList() for layer_id in range(params.n_layers): self.layers.append(TransformerBlock(layer_id, params)) self.norm = RMSNorm(params.dim, eps=params.norm_eps) self.output = ColumnParallelLinear( params.dim, params.vocab_size, bias=False, init_method=lambda x: x ) self.freqs_cis = precompute_freqs_cis( self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 ) @torch.inference_mode() def forward(self, tokens: torch.Tensor, start_pos: int): _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) self.freqs_cis = self.freqs_cis.to(h.device) freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] mask = None if seqlen \u0026gt; 1: mask = torch.full((1, 1, seqlen, seqlen), float(\u0026#34;-inf\u0026#34;), device=tokens.device) mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) h = self.norm(h) output = self.output(h[:, -1, :]) # only compute last logits return output.float() self.tok_embeddings 用的是 ParallelEmbedding 这个函数，把 ids 变为词向量。 mask 部分通过 torch.full() 函数和 torch.triu() 函数得到一个上三角矩阵，用于注意力的计算。 通过 torch.nn.ModuleList() 函数定义所有的 Transformer Block。 所有的 norm 函数都使用 RMSNorm 去定义。\n生成过程的 PyTorch 代码：\nclass LLaMA: def __init__(self, model: Transformer, tokenizer: Tokenizer): self.model = model self.tokenizer = tokenizer def generate( self, prompts: List[str], max_gen_len: int, temperature: float = 0.8, top_p: float = 0.95, ) -\u0026gt; List[str]: bsz = len(prompts) params = self.model.params assert bsz \u0026lt;= params.max_batch_size, (bsz, params.max_batch_size) prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts] min_prompt_size = min([len(t) for t in prompt_tokens]) max_prompt_size = max([len(t) for t in prompt_tokens]) total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) tokens = torch.full((bsz, total_len), self.tokenizer.pad_id).cuda().long() for k, t in enumerate(prompt_tokens): tokens[k, : len(t)] = torch.tensor(t).long() input_text_mask = tokens != self.tokenizer.pad_id start_pos = min_prompt_size prev_pos = 0 for cur_pos in range(start_pos, total_len): logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) if temperature \u0026gt; 0: probs = torch.softmax(logits / temperature, dim=-1) next_token = sample_top_p(probs, top_p) else: next_token = torch.argmax(logits, dim=-1) next_token = next_token.reshape(-1) # only replace token if prompt has already been generated next_token = torch.where( input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token ) tokens[:, cur_pos] = next_token prev_pos = cur_pos decoded = [] for i, t in enumerate(tokens.tolist()): # cut to max gen len t = t[: len(prompt_tokens[i]) + max_gen_len] # cut to eos tok if any try: t = t[: t.index(self.tokenizer.eos_id)] except ValueError: pass decoded.append(self.tokenizer.decode(t)) return decoded def sample_top_p(probs, p): probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) mask = probs_sum - probs_sort \u0026gt; p probs_sort[mask] = 0.0 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) next_token = torch.multinomial(probs_sort, num_samples=1) next_token = torch.gather(probs_idx, -1, next_token) return next_token 这里需要注意的是： torch.multinomial() 函数用于按照一定的概率 (probs_sort) 采样一定数量 (num_samples) 的 Tensor。 torch.gather() 函数是一个抽数据的函数，按照 probs_idx 的索引和 dim=-1 的维度。\n1.5 主要结果与结论 # 1.5.1 主要成果 # 1.5.2 结论 # ​\t在本文中，作者介绍了一系列公开发布且与最先进的基础模型具有竞争力的语言模型。最值得注意的是，LLaMA-13B 在体积仅为 GPT-3 十分之一的情况下性能更优，而 LLaMA-65B 则与 Chinchilla-70B 和 PaLM-540B 相媲美。\n与以往的研究不同，作者展示了仅通过使用公开可用的数据进行训练，而不依赖专有数据集，就可以达到最先进的性能。我们希望向研究社区发布这些模型能够加速大型语言模型的发展，并帮助改善它们的鲁棒性，并减轻诸如毒性和偏见等已知问题。\n​\t此外，论文也像 Chung 等人（2022年）观察到的那样，发现对这些模型进行指令微调可以带来有希望的结果，我们计划在未来的工作中进一步研究这一点。最后，我们计划未来发布在更大的预训练语料库上训练的更大型模型，因为我们已经看到随着规模的扩大性能在持续提升。\n参考：\nLLM 系列超详细解读 (六)：LLaMa：开源高效的大语言模型\n论文精读：LLaMA: Open and Efficient Foundation Language Models\nLLaMA 超详细解读（paper \u0026amp; code）\n理解NLP最重要的编码方式 — Byte Pair Encoding (BPE)，这一篇就够了\n分词（tokenization）算法之Byte Pair Encoding (BPE) 算法详解（代码实现）_bpe算法\n通俗易懂-大模型的关键技术之一：旋转位置编码rope Transformer升级之路：2、博采众长的旋转式位置编码|Scientific Spaces\n你还不懂旋转位置编码吗？\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/llama-series/","section":"文章","summary":"","title":"LLaMA 系列","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/llm/","section":"标签","summary":"","title":"LLM","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/categories/llm%E5%9F%BA%E7%A1%80/","section":"分类","summary":"","title":"LLM基础","type":"categories"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/rope/","section":"标签","summary":"","title":"RoPE","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/lora/","section":"标签","summary":"","title":"LoRA","type":"tags"},{"content":" 1. LoRA # 1.1 前言 # Adapters：\n出现时间：Adapters方法最早出现，其初步形式可以追溯到2016年左右。 方法描述：Adapters通过在模型的每一层之间添加较小的、可训练的网络（称为adapter模块），而不是微调整个模型。这样可以显著减少训练时需要调整的参数数量。 应用：Adapters适用于那些希望在保持预训练模型结构不变的同时，对模型进行特定任务调整的场景。 缺陷：添加适配器层（adapters）的策略虽然参数少，但会在推理阶段引入延迟，特别是在大规模和对延迟敏感的生产环境中。 Prefix Tuning：\n出现时间：Prefix Tuning是在Adapters之后出现的，大约是在2020年左右。 方法描述：在Prefix Tuning中，固定了大部分预训练模型的权重，仅在模型的输入部分添加一系列可训练的前缀向量（prefixes）。这些向量会和输入数据一起被送入模型，从而影响模型的行为。 应用：Prefix Tuning适用于需要对模型进行轻量级微调的场景，特别是当模型非常大，而可用于训练的资源有限时。 缺陷：直接优化输入层激活的方法（prefix）在训练参数方面存在非单调变化，且通过预留部分序列长度用于适应，降低了处理下游任务的序列长度。 Lora (Low-Rank Adaptation)：\n出现时间：Lora是最近几年（大约2021年）出现的方法。\n方法描述：Lora通过向预训练模型的每一层的权重矩阵中添加低秩矩阵来实现微调。这种方法旨在通过改变权重的一个小子集来调整模型的行为，而不是修改整个权重矩阵。如下图所示：\n应用：Lora适用于那些需要在不显著增加计算负担的情况下微调大型模型的场景。\n小结：\n三种方法都是为了解决大型预训练模型微调时存在的参数数量庞大、计算成本高等问题。 Adapters通过添加额外的小型模块进行调整，Prefix Tuning通过修改输入的前缀向量来影响模型，而Lora通过对模型权重的低秩调整来实现微调。 1.2 LoRA的特点和原理 # ​\tLoRA方法冻结预训练模型的权重，并在Transformer架构的每一层注入可训练的秩分解矩阵，大大减少了下游任务的可训练参数数量。LoRA可以通过更少的训练参数和更高的训练吞吐量达到与全面微调相当或更好的效果，并且与适配器不同，没有额外的推理延迟。具体过程如下图所示： ​\t这种方法通过在预训练权重矩阵上加入低秩矩阵B和A，从而实现对模型的微调，同时保持预训练权重冻结。这种设计在训练时只需优化低秩矩阵，大大减少了可训练参数的数量。具体计算公式如下图所示：\n​\t这个公式说明：训练 ΔW\\Delta WΔW 来替代训练 W0W_0W0​。这里 B 的维度为 d∗rd*rd∗r ，A=r∗d,B∗A=d∗dA=r * d, B * A = d * dA=r∗d,B∗A=d∗d，而 r\u0026lt;\u0026lt;dr \u0026lt;\u0026lt; dr\u0026lt;\u0026lt;d。\n​\t基于此，LoRA 可以成功地外推至全参微调：\nAdapters 和 prefix 都无法维持原有架构； 而 LoRA 只是增加了 ΔW\\Delta WΔW，可以维持原有架构。 论文中的说明：LoRA（低秩适应）走得更远，不要求在适应期间累积梯度更新到权重矩阵必须是满秩的。这意味着，当我们将LoRA应用于所有权重矩阵并训练所有偏差时，我们大致上通过将LoRA的秩r设置为预训练权重矩阵的秩，恢复了完全微调的表达能力。换句话说，随着我们增加可训练参数的数量，训练LoRA大致上收敛于训练原始模型，而基于适配器的方法收敛于一个多层感知机（MLP），基于前缀的方法则收敛于一个不能处理长输入序列的模型。 LoRA的特点如下：\nBase LLMs + 不同的 LoRA 支持不同的下游任务。就是定向预调大模型的意思 微调过程中需要相对很少的显存 数学原理：原参数矩阵，加上一个小的、简单的低秩矩阵（B*A升维之后，+W），来生成新的参数矩阵。 在 Transformer 的 multi-head attention 中使用 LoRA 微调后部署：Base + LoRA，不会增加推理时间 可以结合其它微调方法如QLoRA，降低精度，进一步减少计算量 LoRA的优势：\n一个预训练模型可以共享并用于为不同任务构建许多小型LoRA模块。我们可以冻结共享模型，并通过替换图1中的矩阵A和B来高效地切换任务，这显著减少了存储需求和任务切换开销。 LoRA使训练更加高效，并且当使用自适应优化器时，硬件进入门槛降低了高达3倍，因为我们不需要计算大多数参数的梯度或维护优化器状态。相反，我们只优化注入的、更小的低秩矩阵。 我们简单的线性设计允许我们在部署时将可训练矩阵与冻结权重合并，与全面微调模型相比，由于构造原因，不引入推理延迟。 LoRA与许多先前的方法正交，可以与其中许多方法（如前缀调优 prefix-tuning）结合使用。 2. 模型量化 # 2.1 什么是模型量化 # ​\t模型量化是指将神经网络中参数（权重）和激活值（activation）的数值精度从高位（如32位浮点数，FP32）降低为低位（如16位、8位、4位甚至2位整数），以减少存储开销、加速推理并降低能耗的技术。\n其核心思想是：\n用更少的比特来表示相近的数值，同时尽量保持模型性能不受显著影响。\n​\t量化过程一般分为两个阶段： Quantization: q=clip(round(xs)+z,qmin⁡,qmax⁡) \\text{Quantization: } q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s}\\right) + z, q_{\\min}, q_{\\max}\\right) Quantization: q=clip(round(sx​)+z,qmin​,qmax​)Dequantization: x^=s×(q−z) \\text{Dequantization: } \\hat{x} = s \\times (q - z) Dequantization: x^=s×(q−z)其中：\nxxx：原始浮点值 qqq：量化后的整数 sss：scale（缩放因子） zzz：zero-point（零点偏移） ⁡qmin⁡,qmax⁡⁡q_{\\min}, q_{\\max}⁡qmin​,qmax​：量化范围（由比特数决定） ​\t可以对模型参数（weight）、激活值（activation）或者梯度（gradient）做量化。通常而言，模型的参数分布较为稳定，因此对参数 weight 做量化较为容易。\n​\t然而，模型的激活值往往存在异常值，直接对其做量化，会降低有效的量化格点数，导致精度损失严重，因此，激活值的量化需要更复杂的处理方法（如 SmoothQuant）。\n​\t模型量化可以看成模型的压缩/解压过程，也可以理解成模型加密/解密的过程。既然量化算法相当于一个压缩算法，自然我们需要关注：\n压缩比，也就是说，一种量化方法能减少多少内存/显存占用？ 压缩/解压缩的速度，这影响量化模型推理的速度，也是我们需要重点优化之处。 ​\t对于第一个关注点，当我们确定了量化精度（例如 int4），确定了量化方法，以及需要量化模型的哪些 layer，其内存和显存占用就基本确定下来了。大部分情况下，我们都只去量化 nn.Linear 层，目前几乎所有量化策略都是这么做的，而且量化模型的显存占用较少，因此我们几乎不会去考虑怎么进一步减少量化模型的体积。\n​\t对于第二个关注点，我们着重于模型 forward、backward 计算过程的解压缩速度。由于这些计算基本都在 GPU 上进行，所以我们就需要去优化 GPU 的 op 了。\n2.2 数值编码方式 # FP16\nFP16 是一种 浮点格式（不是整数量化），采用 IEEE 754 半精度标准：\n位数 字段 含义 1 bit 符号位（sign） sss 5 bits 阶码位（exponent） eee 10 bits 尾数位（mantissa） mmm 量化公式： xFP16=(−1)s×2(e−15)×(1+m210) x_{FP16} = (-1)^s \\times 2^{(e - 15)} \\times (1 + \\frac{m}{2^{10}}) xFP16​=(−1)s×2(e−15)×(1+210m​) 反量化公式（即从FP16还原FP32）： x^FP32=(−1)s×2(e−15)×(1+m210) \\hat{x}_{FP32} = (-1)^s \\times 2^{(e - 15)} \\times (1 + \\frac{m}{2^{10}}) x^FP32​=(−1)s×2(e−15)×(1+210m​) FP16 实际上是低位浮点存储，不需要显式 scale/z，硬件自动完成。\nFP8\nFP8 同样是浮点格式，主要有两种变体：\nE4M3：4位指数 + 3位尾数 E5M2：5位指数 + 2位尾数 常见于 NVIDIA Hopper / TransformerEngine 中。\n量化公式（以E4M3为例）： xFP8=(−1)s×2(e−7)×(1+m23) x_{FP8} = (-1)^s \\times 2^{(e - 7)} \\times (1 + \\frac{m}{2^{3}}) xFP8​=(−1)s×2(e−7)×(1+23m​) 其中 e∈[0,15]e \\in [0,15]e∈[0,15]，m∈[0,7]m \\in [0,7]m∈[0,7]。\n反量化公式： x^FP32=(−1)s×2(e−7)×(1+m8) \\hat{x}_{FP32} = (-1)^s \\times 2^{(e - 7)} \\times (1 + \\frac{m}{8}) x^FP32​=(−1)s×2(e−7)×(1+8m​) FP4\nFP4 是一种实验性低比特浮点格式，常用于研究级或自定义硬件。\n位数 含义 1 bit sign 2 bits exponent 1 bit mantissa 量化公式： xFP4=(−1)s×2(e−1)×(1+m2) x_{FP4} = (-1)^s \\times 2^{(e - 1)} \\times (1 + \\frac{m}{2}) xFP4​=(−1)s×2(e−1)×(1+2m​) 反量化公式： x^FP32=(−1)s×2(e−1)×(1+m2) \\hat{x}_{FP32} = (-1)^s \\times 2^{(e - 1)} \\times (1 + \\frac{m}{2}) x^FP32​=(−1)s×2(e−1)×(1+2m​) FP4 的动态范围极小，因此常用于辅助存储或特定层。\nINT8\n量化公式： q=clip(round(xs)+z,−128,127) q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s}\\right) + z, -128, 127\\right) q=clip(round(sx​)+z,−128,127) 反量化公式： x^=s×(q−z) \\hat{x} = s \\times (q - z) x^=s×(q−z) 其中：\ns=xmax⁡−xmin⁡28−1s = \\frac{x_{\\max} - x_{\\min}}{2^{8}-1}s=28−1xmax​−xmin​​ z=round(−xmin⁡s)z = \\text{round}\\left(-\\frac{x_{\\min}}{s}\\right)z=round(−sxmin​​) INT4\n量化公式： q=clip(round(xs)+z,−8,7) q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s}\\right) + z, -8, 7\\right) q=clip(round(sx​)+z,−8,7) 反量化公式： x^=s×(q−z) \\hat{x} = s \\times (q - z) x^=s×(q−z) 其中：\ns=xmax⁡−xmin⁡24−1=xmax⁡−xmin⁡15s = \\frac{x_{\\max} - x_{\\min}}{2^{4}-1} = \\frac{x_{\\max} - x_{\\min}}{15}s=24−1xmax​−xmin​​=15xmax​−xmin​​ z=round(−xmin⁡s)z = \\text{round}\\left(-\\frac{x_{\\min}}{s}\\right)z=round(−sxmin​​) 2.3 常用的模型量化方法 # 2.3.1 训练后量化 PTQ # ​\t训练后量化（Post-Training Quantization）在模型训练完成后进行量化；速度快、无额外训练代价，但精度下降明显。\n1. 训练后动态量化\n​\t将模型的权重提前量化为INT8，但激活值在推理过程中动态地量化为INT8。这种方法最简单，通常用于LSTM等模型。\n代码如下：\n# -*- coding: utf-8 -*- import torch class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(3, 3, bias=False) self.relu = torch.nn.ReLU() self.linear2 = torch.nn.Linear(3, 1, bias=False) def forward(self, inputs): outputs = self.linear1(inputs) outputs = self.relu(outputs) outputs = self.linear2(outputs) return outputs # 构造训练数据 weights = torch.tensor([[1.1], [2.2], [3.3]]) torch.manual_seed(123) training_features = torch.randn(12000, 3) training_labels = training_features @ weights # 构造测试数据 torch.manual_seed(123) test_feature = torch.randn(1000, 3) test_labels = test_feature @ weights # 初始化模型与优化器 model = MyModel() optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 训练过程 for i in range(100): preds = model(training_features) loss = torch.nn.functional.mse_loss(preds, training_labels) loss.backward() optimizer.step() optimizer.zero_grad() # 测试 float32 模型 model.eval() with torch.no_grad(): preds = model(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;float32 model testing loss: {mse.item():.3f}\u0026#34;) # 动态量化（int8） model_int8 = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) with torch.no_grad(): preds = model_int8(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;int8 model testing loss: {mse.item():.3f}\u0026#34;) # 查看参数 print(\u0026#34;float32 model linear1 parameter:\\n\u0026#34;, model.linear1.weight) print(\u0026#34;int8 model linear1 parameter (int representation):\\n\u0026#34;, torch.int_repr(model_int8.linear1.weight())) print(\u0026#34;int8 model linear1 parameter (dequantized):\\n\u0026#34;, model_int8.linear1.weight()) 训练后动态量化的问题：\n每一次推理每一层都要对输入统计量化参数，耗时。 每一层计算完都转换为fp32，存入显存，占用显存带宽 2. 训练后静态量化\n​\t在模型推理之前，通过一个**“校准”过程（输入一批有代表性的数据）来统计激活值的分布范围，确定一个固定的量化参数（缩放因子和零点）**。然后将权重和激活值都静态地量化为INT8。这是最常用、性能最好的后量化方法。\n针对训练后动态量化的问题\n用有代表性的输入数据跑一遍整个网络，通过统计得到每层大概的量化参数。 这一层的输出是下一层的输入。下一层还要量化，不如在这一层直接量化在传递给下一层。 训练后静态量化过程如下图所示：\n代码如下：\nimport torch class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() self.linear1 = torch.nn.Linear(3, 3, bias=False) self.relu = torch.nn.ReLU() self.linear2 = torch.nn.Linear(3, 1, bias=False) self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, inputs): q_inputs = self.quant(inputs) outputs = self.linear1(q_inputs) outputs = self.relu(outputs) outputs = self.linear2(outputs) f_outputs = self.dequant(outputs) return f_outputs # 构造训练数据 weights = torch.tensor([[1.1], [2.2], [3.3]]) torch.manual_seed(123) training_features = torch.randn(12000, 3) training_labels = training_features @ weights # 构造测试数据 torch.manual_seed(123) test_feature = torch.randn(1000, 3) test_labels = test_feature @ weights # 初始化模型与优化器 model = MyModel() optimizer = torch.optim.Adam(model.parameters(), lr=0.1) # 训练 for i in range(100): preds = model(training_features) loss = torch.nn.functional.mse_loss(preds, training_labels) loss.backward() optimizer.step() optimizer.zero_grad() # 测试 float32 模型 model.eval() with torch.no_grad(): preds = model(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;float32 model testing loss: {mse.item():.3f}\u0026#34;) # 设置量化配置（必须执行这步） model.qconfig = torch.ao.quantization.get_default_qconfig(\u0026#34;x86\u0026#34;) # 准备量化（插入 observer） model_prepared = torch.ao.quantization.prepare(model) # 用代表性数据校准 observer model_prepared(test_feature) # 转换成量化模型 model_int8 = torch.ao.quantization.convert(model_prepared) # 测试量化模型 with torch.no_grad(): preds = model_int8(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;int8 model testing loss: {mse.item():.3f}\u0026#34;) # 查看量化参数 print(\u0026#34;float32 model linear1 weight:\\n\u0026#34;, model.linear1.weight) print(\u0026#34;int8 model linear1 weight (int representation):\\n\u0026#34;, torch.int_repr(model_int8.linear1.weight())) print(\u0026#34;int8 model linear1 weight (dequantized):\\n\u0026#34;, model_int8.linear1.weight()) 2.3.2 量化感知训练 # ​\t对训练好的模型，无论怎么量化，总是会有误差。量化感知训练（Quantization-Aware Training）是在网络训练过程中，模拟量化，让模型在训练过程中就能调整参数，让它更适合量化，提高量化后模型的精度。在 QAT 中，前向传播通过“假量化”模块模拟 INT8 运算，使模型在训练中感知量化误差；反向传播中，梯度通过 Straight-Through Estimator 回传到 FP32 权重，使权重在高精度空间中更新。最终在推理阶段才真正将权重与激活量化为 INT8。具体过程如下图所示：\n代码如下：\n# -*- coding: utf-8 -*- import torch class MyModel(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.ao.quantization.QuantStub() self.linear1 = torch.nn.Linear(3, 3, bias=False) self.relu = torch.nn.ReLU() self.linear2 = torch.nn.Linear(3, 1, bias=False) self.dequant = torch.ao.quantization.DeQuantStub() def forward(self, inputs): q_inputs = self.quant(inputs) outputs = self.linear1(q_inputs) outputs = self.relu(outputs) outputs = self.linear2(outputs) f_outputs = self.dequant(outputs) return f_outputs # 构造数据 weights = torch.tensor([[1.1], [2.2], [3.3]]) torch.manual_seed(123) training_features = torch.randn(12000, 3) training_labels = training_features @ weights torch.manual_seed(123) test_feature = torch.randn(1000, 3) test_labels = test_feature @ weights model = MyModel() # 设置量化配置（必须执行这步） model.qconfig = torch.ao.quantization.get_default_qconfig(\u0026#34;x86\u0026#34;) # 准备量化（插入 observer） model_prepared = torch.ao.quantization.prepare_qat(model) # 训练模型 optimizer = torch.optim.Adam(model.parameters(), lr=0.1) for i in range(100): preds = model_prepared(training_features) loss = torch.nn.functional.mse_loss(preds, training_labels) loss.backward() optimizer.step() optimizer.zero_grad() # 测试浮点模型 model.eval() with torch.no_grad(): preds = model_prepared(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;float32 model testing loss: {mse.item():.3f}\u0026#34;) # 转换成量化模型 model_int8 = torch.ao.quantization.convert(model_prepared) # 测试量化模型 with torch.no_grad(): preds = model_int8(test_feature) mse = torch.nn.functional.mse_loss(preds, test_labels) print(f\u0026#34;float32 model testing loss: {mse.item():.3f}\u0026#34;) # 查看量化参数 print(\u0026#34;float32 model linear1 parameter:\\n\u0026#34;, model_prepared.linear1.weight) print(\u0026#34;int8 model linear1 parameter (int representation):\\n\u0026#34;, torch.int_repr(model_int8.linear1.weight())) print(\u0026#34;int8 model linear1 parameter (dequantized):\\n\u0026#34;, model_int8.linear1.weight()) 2.3.3 对称量化 # ​\t对称量化假设数据分布关于 0 近似对称（如权重分布），因此零点（zero-point）固定为 0。量化范围也关于 0 对称。\n量化公式： q=clip(round(xs),−qmax,qmax) q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s}\\right), -q_{max}, q_{max}\\right) q=clip(round(sx​),−qmax​,qmax​) 其中：\nxxx：浮点数 qqq：量化后整数 sss：缩放因子（scale） qmaxq_{max}qmax​：量化区间上界（如 int8 → 127） 反量化公式： x^=s×q \\hat{x} = s \\times q x^=s×q 缩放因子计算：\n通常取对称范围： s=max⁡(∣xmin∣,∣xmax∣)qmax s = \\frac{\\max(|x_{min}|, |x_{max}|)}{q_{max}} s=qmax​max(∣xmin​∣,∣xmax​∣)​ 示例（以 int8 为例）\nqmin=−127,qmax=127q_{min} = -127, q_{max} = 127qmin​=−127,qmax​=127\n若浮点权重范围是 [−2.5,3.0][−2.5,3.0][−2.5,3.0]，则 s=max⁡(∣−2.5∣,∣3.0∣)127=3.0127≈0.0236 s = \\frac{\\max(|-2.5|, |3.0|)}{127} = \\frac{3.0}{127} \\approx 0.0236 s=127max(∣−2.5∣,∣3.0∣)​=1273.0​≈0.0236 某个浮点数 x=1.2x = 1.2x=1.2 的量化结果：q=round(1.2/0.0236)=51q = \\text{round}(1.2 / 0.0236) = 51q=round(1.2/0.0236)=51\n反量化结果：x^=0.0236×51=1.2036\\hat{x} = 0.0236 \\times 51 = 1.2036x^=0.0236×51=1.2036\n2.3.4 非对称量化 # ​\t非对称量化考虑到数据可能不以 0 为中心（例如 ReLU 激活输出通常为非负），因此引入**零点（Zero-Point, z）**以对齐量化整数的零值。\n量化公式： q=clip(round(xs)+z,qmin,qmax) q = \\text{clip}\\left(\\text{round}\\left(\\frac{x}{s}\\right) + z, q_{min}, q_{max}\\right) q=clip(round(sx​)+z,qmin​,qmax​) 其中：\nsss：scale（比例因子） zzz：zero-point（偏移量） qmin,qmaxq_{min}, q_{max}qmin​,qmax​：量化整数范围（如 0~255 对应 uint8） 反量化公式： x^=s×(q−z) \\hat{x} = s \\times (q - z) x^=s×(q−z) 缩放因子和零点计算： s=xmax−xminqmax−qmin,z=round(qmin−xmins) s = \\frac{x_{max} - x_{min}}{q_{max} - q_{min}}, \\quad z = \\text{round}\\left(q_{min} - \\frac{x_{min}}{s}\\right) s=qmax​−qmin​xmax​−xmin​​,z=round(qmin​−sxmin​​) 示例（以 uint8 为例）\n假设激活范围为 [0.5,5.5][0.5,5.5][0.5,5.5]，\nqmin=0,qmax=255q_{min}=0, q_{max}=255qmin​=0,qmax​=255\n计算：\ns=5.5−0.5255−0=5255≈0.0196s = \\frac{5.5 - 0.5}{255 - 0} = \\frac{5}{255} \\approx 0.0196s=255−05.5−0.5​=2555​≈0.0196\nz=round(0−0.5/0.0196)=−26 (clip到[0,255]范围后取0)z = \\text{round}(0 - 0.5 / 0.0196) = -26 \\;\\text{(clip到[0,255]范围后取0)}z=round(0−0.5/0.0196)=−26(clip到[0,255]范围后取0)\n量化：q=round(2.50.0196)+0=128q = \\text{round}\\left(\\frac{2.5}{0.0196}\\right) + 0 = 128q=round(0.01962.5​)+0=128\n反量化：x^=0.0196×(128−0)=2.5088\\hat{x} = 0.0196 \\times (128 - 0) = 2.5088x^=0.0196×(128−0)=2.5088\n2.3 常用的量化方法在LLM中不适用 # ​\t传统的量化方法（如对称/非对称定点量化，典型是 INT8 量化）在 LLM（如 GPT、LLaMA、BERT 大模型）上会遇到明显问题，主要原因如下：\nLLM 对精度极为敏感 LLM 的参数规模动辄数百亿甚至上万亿，网络结构极深。 一点量化误差在层间传播会被放大，造成输出语义漂移。 特别是 注意力层（Attention）和归一化层（LayerNorm） 对精度变化极度敏感。 传统的 8-bit 量化方法往往会导致严重的性能下降（PPL上升、生成文本不连贯等）。 权重和激活分布高度不均匀 CNN 的激活值通常集中在较小范围内，而 LLM 的激活值和权重分布呈 重尾（heavy-tailed）分布。 少数异常值（outlier）占据较大动态范围，导致普通线性量化（如 min-max）无法兼顾大部分数据的精度。 不同层、不同通道的分布差异大 LLM 的不同 Transformer 层、甚至不同矩阵的列/行统计特性差异巨大。 传统统一比例的量化（per-tensor）会极大破坏部分层的数值关系。 → 因此需要更细粒度的量化（per-channel、per-group）甚至自适应量化策略。 2.4 用于量化LLM的方法 # ​\t近年来，研究者提出了多种专为 LLM 优化的量化技术，主要分为以下几类：\n类别 方法代表 特点 1. Post-Training Quantization (PTQ) GPTQ, AWQ, RPTQ, OmniQuant 无需重新训练，通过统计或优化减少量化误差；部署简单，效率高。 2. Quantization-Aware Training (QAT) QLoRA, SmoothQuant+再训练GPTQ, AWQ, RPTQ, OmniQuant 在训练或微调过程中引入量化噪声模拟，提高模型对低精度的适应性。无需重新训练，通过统计或优化减少量化误差；部署简单，效率高。 3. 混合精度量化 (Hybrid Precision) LLM.int8(), FP8, INT4+FP16 混合 对敏感层使用高精度（如 FP16），对鲁棒层使用低比特（如 INT4），平衡性能与精度。 下面介绍目前主流的方法：\nGPTQ（Gradient Post-training Quantization）\n原理：通过最小化量化误差对下游任务影响的梯度近似优化。 逐层量化，每层用二阶信息近似量化误差最小化。 支持 INT4/INT3，精度损失极小。 目前是 LLM 最常用的离线量化方案之一（如 LLaMA、OPT、Vicuna 等均有 GPTQ 版本）。 AWQ（Activation-aware Weight Quantization）\n由 MIT 提出，主要解决激活值中 outlier 过大的问题。 方法：识别关键通道（outlier channel），保留其高精度表示，对其他部分低比特量化。 效果：在几乎无性能损失的前提下实现 INT4 推理。 支持大部分 Transformer 结构，实际推理速度快于 GPTQ。 SmoothQuant\n由微软提出，用于 QAT 或半PTQ。 思路：通过平滑（smooth）激活与权重的比例，使得量化区间更均衡。 通常配合 INT8 推理，尤其在 CPU 或 GPU 部署上效果好。 可用于模型蒸馏或再训练阶段。 QLoRA（Quantized Low-Rank Adapter）\n来自 HuggingFace，用于 量化后的微调（LoRA）。 核心思想：将预训练模型权重量化为 INT4，但微调时仅训练低秩适配器（LoRA 层）。 优点：显著减少显存占用，可在单张 24GB GPU 上微调 65B 模型。 缺点：只适合微调，不直接用于推理优化。 LLM.int8（8-bit Matrix Multiplication for LLMs）\n由 Tim Dettmers 等人提出，专为大语言模型（LLM）设计的 8-bit 混合精度量化方法。\n**核心思想：**在 Transformer 中，不同通道的激活分布差异较大，存在少量异常激活值（outliers），若直接量化会造成精度显著下降。LLM.int8 通过检测这些异常通道，仅将分布稳定的部分量化为 INT8，而将异常通道保留为 FP16 精度，从而实现 INT8 与 FP16 的混合计算，既降低显存占用又保持模型性能。\n优点：\n无需再训练（Post-Training Quantization）。 显存占用减少约 50%，几乎无精度损失。 已在 Hugging Face 的 BitsAndBytes 库中广泛实现，部署方便。 缺点：\n计算中仍包含 FP16 通道，速度提升有限。 依赖支持 INT8 加速的 GPU 硬件（如 A100、H100）。 主要用于推理阶段，不适用于进一步微调。 3.QLoRA # 3.1 背景与动机 # ​\t大型语言模型（LLMs）的微调通常需要大量内存，这限制了在资源受限环境中的应用。**如常规16位微调65B参数的LLaMA模型需要超过780GB的GPU内存（一张A100 GPU 的显存是64GB）。**现有方法如LoRA虽然减少参数，但仍有内存瓶颈。最近的量化方法能进一步减少LLMs的内存占用，但这些技术仅适用于推理，在训练期间会失效。要 在量化模型上做训练 / 微调 ，尤其要保留训练表现，是极具挑战的。\n​\t论文提出了QLoRA，一种高效的微调方法，它减少了内存的使用，足以在单个48GB GPU上微调65B参数的模型，同时保留了完整的16位微调任务性能。QLoRA通过一个冻结的、4位量化的预训练语言模型反向传播梯度到低秩适配器（LoRA）。这标志着LLM微调可访问性的重大转变：现在最大的公开可用模型可以在单个GPU上进行微调。\n​\tQLoRA引入了许多创新来节省内存而不牺牲性能：(a) 4位NormalFloat（NF4），一种对于正态分布权重理论上最优的新数据类型；(b) 双重量化，通过**量化量化常数（缩放因子）**来减少平均内存占用；(c) 分页优化器来管理内存峰值。\n与QLoRA相关的技术如下：\n块状 k-bit 量化（Block-wise k-bit Quantization）：这是一种数据压缩技术，通过减少表示数据的比特数（如从32位浮点数到8位整数）来减少模型大小。为了有效使用低比特数据类型的整个范围，通常会通过最大绝对值归一化输入数据。然而，这种方法存在一个问题：如果输入数据中有极大或极小的异常值，一些量化区间将不会被充分利用。为了解决这个问题，可以将输入数据划分为块，每个块独立进行量化。如下图所示： LoRA（Low-Rank Adaptation） ：LoRA 是一种常用的参数高效微调技术：不直接修改基模型权重，而是在每个线性层插入两个小矩阵（低秩矩阵）去拟合增量，并只训练这些适配器（冻结原权重）。这样可以大幅减少要训练的参数数量。如下图所示： 其中 sss 是一个 标量缩放参数\n​\tsss 是为了补偿量化引入的尺度差异；\n​\t它是一个可学习的标量；\n​\t能够在微调时自动调整 LoRA 的影响强度；\n​\t保证在低精度量化下模型依然稳定、有效。\n参数高效微调的内存需求（Memory Requirement of Parameter-Efficient Finetuning）：这部分讨论了在训练期间，使用低秩适配器（LoRA）时的内存需求。 每个Transformer module，都加上一个 LoRA module； 在不同地方加上 LoRA 模块，会带来不同效果； LoRA的内存占用很小，因此可以使用更多的适配器来提高性能，而不会显著增加总体内存使用。然而，训练时最大的内存开销来自激活梯度，而不是LoRA参数本身。 这些概念是实现QLoRA方法的基础，通过它们，论文展示了如何在保持性能的同时显著减少大型语言模型微调过程中的内存需求。\n3.2 核心要点 # QLoRA 在逻辑上由以下步骤组成：\n1. 量化预训练模型：将基模型权重量化为 4-bit 表示（采用 NF4 类型 + 双重量化）\n加入 LoRA Adapter：在每层的线性/投影矩阵上插入 LoRA（低秩）参数，并冻结原始基模型权重 前向／反向计算时 dequantize → 16-bit 计算：量化存储 + 在需要时恢复高精度 把梯度仅传回给 LoRA 参数（通过 dequantize 路径传播） 使用分页优化器 (paged optimizer) 管理内存峰值（尤其 gradient checkpointing 情况下） 作者强调：为了在量化微调中恢复与 16-bit 相当的性能，必须在所有适当的线性层都加 LoRA，而不能只选部分层。\n3.3 QLoRA微调 # ​\t通过提出的两种技术实现了高保真度的4位微调——4位NormalFloat（NF4）量化和双重量化。此外，论文引入了分页优化器，以防止在梯度检查点期间内存峰值导致内存不足错误，这种错误传统上使得在单台机器上对大型模型进行微调变得困难。\n​\tQLoRA有一种低精度的存储数据类型，通常是4位，和一个通常为BFloat16的计算数据类型。在实践中，这意味着每当使用QLORA权重张量时，我们将其解量化为BFloat16，然后以16位进行矩阵乘法。\n​\t值得注意的是：NF4 并不是一种硬件支持的数据类型（不像 INT4 / FP16 那样），而是一种“非线性量化映射方式”。它只是把符合正态分布的权重映射到 16 个非均匀分布的浮点查表值上，以减少量化误差。因此，在推理计算（矩阵乘法）时不能直接在 INT4 上计算，必须先反量化（dequantize）回浮点数BF16再参与计算。\n3.3.1 NF4 # ​\tNF4是一种量化数据类型，特别适用于正态分布的数据。它通过量化过程将数据的精度降低到4位，从而减少数据存储和计算所需的内存和带宽。\n​\tNormalFloat（NF）数据类型基于分位数量化，这是一种理论上最优的数据类型，确保每个量化区间从输入张量中分配到相等数量的值。分位数量化通过估计输入张量的经验累积分布函数来工作。\n​\t分位数量化的主要限制是分位数估计过程昂贵。因此，使用快速分位数近似算法，例如SRAM分位数，来估计它们。由于这些分位数估计算法的近似性质，该数据类型对离群值的量化误差较大，而这些离群值通常是最重要的值（混合精度模型载入）。当输入张量来自一个固定到量化常数的分布时，可以避免昂贵的分位数估计和近似误差。在这种情况下，输入张量具有相同的分位数，使得精确的分位数估计在计算上是可行的。\n​\t由于预训练神经网络权重通常具有以零为中心的正态分布，标准差为 σ\\sigmaσ ，我们可以通过缩放 σ\\sigmaσ 将所有权重转换为单一固定分布，使得分布正好适合我们数据类型的范围。对于我们的数据类型，我们设置任意范围[−1, 1]。因此，数据类型和神经网络权重的分位数都需要标准化到这个范围内。\n对于标准偏差 σ\\sigmaσ 在 [−1,1][−1, 1][−1,1] 范围内的零均值正态分布的信息理论上最优数据类型的计算如下：\n估计理论上的 N(0,1)N(0, 1)N(0,1) 分布的 2k+12^k+12k+1 分位数，以获得正态分布的 kkk 位分位数量化数据类型； 将此数据类型的值标准化到 [−1,1][−1, 1][−1,1] 范围内； 通过绝对最大重缩放将输入权重张量量化并标准化到 [−1,1][−1, 1][−1,1] 范围内。 ​\t一旦权重范围和数据类型范围匹配，我们就可以像往常一样进行量化。步骤(3)相当于将权重张量的标准偏差重缩放以匹配 k位数 据类型的标准偏差。更正式地，我们如下估计数据类型的 2k2^k2k 个值 qiq_iqi​： qi=12(QX(i2k+1)+QX(i+12k+1)) q_i=\\frac{1}{2}\\big (Q_X(\\frac{i}{2^k+1})+Q_X(\\frac{i+1}{2^k+1}) \\big) qi​=21​(QX​(2k+1i​)+QX​(2k+1i+1​)) ​\t当 QX(⋅)Q_X(·)QX​(⋅) 为标准正态分布 N(0,1)N(0, 1)N(0,1) 的分位数函数时，对称的 kkk 位量化的一个问题是这种方法不能精确表示零，而零在神经网络中很重要（例如填充、零初始化）。为了确保0的离散零点，并使用所有 2k2^k2k 位表示 kkk 位数据类型，我们通过估计两个范围的分位数 qiq_iqi​ 来创建一个非对称数据类型：\n负部分：使用 2k−12^{k-1}2k−1 个量化值覆盖负值范围（对于k=4，有8个值）。这些值对应概率范围[0, 0.5]，分位数点从 j/(2k−1+1)j/(2^{k-1} + 1)j/(2k−1+1) 或类似方式计算，但论文未详细说明具体概率值。\n正部分：使用 2k−1+12^{k-1} + 12k−1+1 个量化值覆盖正值范围（对于k=4，有9个值）。这些值对应概率范围[0.5, 1]。\n合并和去重：将负部分和正部分的量化值集合合并，由于零在两者中都出现（作为边界），移除一个重复的零。最终得到 2k2^k2k 个值（对于k=4，16个值）。\n​\t我们将这种具有每个量化区间内预期值数量相等的结果数据类型称为 kkk 位NormalFloat（NFk），因为这种数据类型对于以零为中心的正态分布数据在信息论上是最优的。\n​\t由于论文中的方法需要计算 2k+12^k+12k+1 个分位数，然后取平均，这可能会在边界处出现无穷大的问题（因为正态分布的分位数在p=0或1时是无穷大）。所以在实现中通常选择一个偏移量（offset=0.9677083）来避免计算极端分位数。这实际上是一种近似，它通过选择一组关键分位数来覆盖大部分概率范围，然后通过归一化到[-1,1]来得到量化值。这种方法可能在实际应用中表现良好，且计算简单。具体代码如下：\n# -*- codeing = utf-8 -*- import torch from scipy.stats import norm def create_normal_map(offset=0.9677083): # 正数部分 v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1].tolist()) # 正数部分 v1 = v1.tolist() # 中间部分（0） v2 = [0] # 负数部分 v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1].tolist())).tolist() # 负数部分 # 合并正数、零、负数部分 v = v1 + v2 + v3 # 转为 torch.Tensor 并排序（从负到正） values = torch.tensor(v) values = values.sort().values # 归一化到 [-1, 1] values /= values.max() return values # 运行查看结果 nf4_values = create_normal_map() print(\u0026#34;NF4 查表分位点：\u0026#34;) print(nf4_values) 计算得到分位数如下：\nImportant NF4_quant_levels = [ -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 ]\n3.3.2 双重量化 # ​\t双重量化是对量化常数进行二次量化，以进一步节省内存。尽管精确的4位量化需要较小的块大小，但它也带来了相当大的内存开销。例如，使用32位常数和64的块大小对W进行量化时，量化常数平均每个参数增加了 32/64=0.532/64=0.532/64=0.5 位。双重量化有助于减少量化常数的内存占用。\n​\t更具体地说，双重量化将第一次量化的量化常数 cFP32c^{FP32}cFP32 作为第二次量化的输入。这一步产生了量化后的量化常数 cFP8c^{FP8}cFP8 和第二级量化常数 cFP32c^{FP32}cFP32 。我们对第二次量化使用8位浮点数和256的块大小，因为观察到对于8位量化，性能没有降低。**由于 cFP32c^{FP32}cFP32是正数，我们在量化前从c2中减去平均值，使值围绕零居中，并利用对称量化。**平均来说，对于64的块大小，这种量化方式将每个参数的内存占用从 32/64=0.532/64 = 0.532/64=0.5 位降低到 8/64+32/(64⋅256)=0.1278/64 + 32/(64 · 256) = 0.1278/64+32/(64⋅256)=0.127 位，降低了0.373位。如下图所示：\n3.3.3 分页优化器 # ​\t论文使用了NVIDIA的统一内存特性，在GPU偶尔内存不足时，能够自动在CPU和GPU之间进行页面到页面的传输，确保GPU处理过程中无错误。这个特性类似于CPU RAM和磁盘之间的常规内存分页。我们使用这个特性来为优化器状态分配分页内存，当GPU内存不足时，这些状态会自动被转移到CPU RAM中，并在优化器更新步骤中需要内存时再次分页回GPU内存。\n3.3.4 QLoRA # ​\t使用上述组件，论文定义了单个线性层中的QLoRA，在量化基础模型中配备单个LoRA适配器。其中使用 NF4 作为 W，使用 FP8 作为 c2。使用 64 的块大小来提高 W 的量化精度，并使用 256 的块大小来节省 c2 的内存。\nAdapter 布局：在模型的每个适合线性层（例如注意力中的查询/键/值/输出投影矩阵、MLP 层的线性层等）加入 LoRA adapter，并冻结原始量化基模型的权重。\n前向 / 反向流程\n前向时：把量化的权重 dequantize → 与输入相乘 → 加上 adapter 的输出 → 构成最终输出 反向传播时：梯度从损失流回 adapter 参数（LoRA），而原量化权重保持不更新 在梯度传播过程中，梯度不会传给主量化权重，而是直接进入 LoRA 参数，这样保证量化误差不会破坏主模型。 ​\t**总结：**QLoRA有一个存储数据类型（通常是4位的NormalFloat）和一个计算数据类型（16位的BrainFloat）。我们将存储数据类型解量化为计算数据类型，以执行前向和后向传递，但我们只为使用16位BrainFloat的LoRA参数计算权重梯度。\n3.4 论文后续部分 # ​\t第四章及之后的内容主要介绍了QLoRA的实验验证与性能分析。作者在多个大型语言模型（如 LLaMA-7B、13B、33B、65B）上进行了广泛实验，比较了QLoRA 与全参数微调、LoRA、Adapter、Prefix-tuning 等方法的性能与资源消耗。结果表明，QLoRA 仅需 4-bit 量化权重即可在几乎不损失性能的情况下完成高效微调，显著降低显存占用（最多节省约 75% 资源）。在多项基准（如 Alpaca、Vicuna、OpenAssistant 及多任务评测集）上，QLoRA 微调的模型达到了与全参数微调相当甚至更优的效果。此外，作者还介绍了 QLoRA 的可扩展性、稳定性分析以及开源实现细节（PEFT框架集成）。\n​\t论文还分析了训练模型的趋势。首先，结果发现数据质量比数据集大小更重要，例如，一个9k样本数据集（OASST1）在聊天机器人性能上超过了一个450k样本数据集（FLAN v2，抽样子集）即使两者都旨在支持指令遵循泛化。其次，论文展示了在Massive Multitask Language Understanding (MMLU)基准测试上的强劲表现并不意味着在Vicuna聊天机器人基准测试上同样表现强劲，反之亦然——换句话说，对于给定任务，数据集的适用性比大小更重要。\n​\t最后在 结论部分，论文总结了 QLoRA 在节省计算资源与保持模型性能方面的突出优势，并展望了未来在大规模分布式微调、混合精度优化、以及自动量化策略方面的进一步研究方向。\n4. 知识蒸馏 # 4.1 研究背景与动机 # ​\t在一般情况下，我们不会去区分训练和部署使用的模型，但是训练和部署之间存在着一定的不一致性:\n​\t在训练过程中，我们需要使用复杂的模型，大量的计算资源，以便从非常大、高度冗余的数据集中提取出信息。在实验中，效果最好的模型往往规模很大，甚至由多个模型集成得到。而大模型不方便部署到服务中去，常见的瓶颈如下:\n推断速度慢 对部署资源要求高(内存，显存等) ​\t随着深度学习的发展，模型结构越来越复杂、参数量庞大，计算与存储开销显著增加。这种复杂模型（如ResNet、Transformer、YOLO系列大型模型）在准确率上表现优异，但难以部署在计算资源受限的设备（如移动端、嵌入式系统）上。 因此，研究者提出了**知识蒸馏（KD）**方法，其核心思想是：\n利用一个性能较高但庞大的“教师模型”（Teacher Network）指导一个较小的“学生模型”（Student Network）学习，以实现模型压缩与性能保持的平衡。\n​\t该思想最早由 Hinton 等人（2015） 在论文 “Distilling the Knowledge in a Neural Network” 中提出，被认为是知识蒸馏领域的奠基性工作。\n4.2 核心思想 # 知识蒸馏通过引入“软标签（Soft Labels）”传递教师网络的“暗知识（Dark Knowledge）”。\n教师网络在训练后输出的是每个类别的概率分布； 学生网络通过模仿这种概率分布，不仅学习到最终分类结果（硬标签），还学习到类别间的相似关系。 这使学生模型在训练中能获得比硬标签更多的信息，从而提升泛化能力。\n4.3 主要组成部分 # 教师网络（Teacher Network） 通常为预训练好的高性能模型； 参数量大、计算量高； 输出用于生成软标签，作为学生学习的指导信号。 学生网络（Student Network） 结构轻量，待训练； 目标是学习教师网络的知识表示； 可与教师网络结构相同或不同。 蒸馏目标与损失函数 ​\t在蒸馏过程中，通常采用两种监督信号：\n硬标签（Hard Label）：来自数据集的真实标签，使用普通的交叉熵损失； 软标签（Soft Label）：来自教师模型的输出分布，经过温度平滑处理后用于蒸馏损失计算。 综合损失函数为： L=(1−α)Lhard+αT2Lsoft L = (1 - \\alpha)L_{hard} + \\alpha T^2 L_{soft} L=(1−α)Lhard​+αT2Lsoft​ 其中：\nTTT：温度参数，用于平滑教师网络输出； α\\alphaα：加权系数，控制两部分损失的平衡。 在 softmax 前将教师输出 logits 除以温度 T\u0026gt;1T\u0026gt;1T\u0026gt;1： pi=ezi/T∑jezj/T p_i = \\frac{e^{z_i/T}}{\\sum_j e^{z_j/T}} pi​=∑j​ezj​/Tezi​/T​ ​\tTTT 越高，softmax输出的概率分布越趋于平滑，其分布的熵越大，负标签携带的信息会被相对地放大，模型训练将更加关注负标签。\n​\t温度 TTT 的特点：\n原始的softmax函数是 T=1T=1T=1 时的特例，T\u0026lt;1T\u0026lt;1T\u0026lt;1 时，概率分布比原始更“陡峭”，T\u0026gt;1T\u0026gt;1T\u0026gt;1 时，概率分布比原始更“平缓”。 温度越高，softmax上各个值的分布就越平均（思考极端情况: (i) T=∞T=\\inftyT=∞ , 此时softmax的值是平均分布的；(ii) T→0T\\rightarrow0T→0 ，此时softmax的值就相当于 argmaxargmaxargmax , 即最大的概率处的值趋近于1，而其他值趋近于0） 不管温度T怎么取值，Soft target都有忽略相对较小的 携带的信息的倾向 ​\t温度的高低改变的是学生网络训练过程中对负标签的关注程度: 温度较低时，对负标签的关注，尤其是那些显著低于平均值的负标签的关注较少；而温度较高时，负标签相关的值会相对增大，学生网络会相对多地关注到负标签。\n​\t实际上，负标签中包含一定的信息，尤其是那些值显著高于平均值的负标签。但由于教师网络的训练过程决定了负标签部分比较noisy，并且负标签的值越低，其信息就越不可靠。因此温度的选取需要通过实验决定，本质上就是在下面两件事之中取舍:\n从有部分信息量的负标签中学习 \u0026ndash;\u0026gt; 温度要高一些 防止受负标签中噪声的影响 \u0026ndash;\u0026gt;温度要低一些 ​\t总的来说， TTT 的选择和学生网络的大小有关，学生网络参数量比较小的时候，相对比较低的温度就可以了（因为参数量小的模型不能捕捉到所有知识，所以可以适当忽略掉一些负标签的信息）\n4.4 蒸馏过程与损失函数 # 第一步是训练教师网络；\n第二步是在高温 TTT 下，蒸馏教师网络的知识到学生网络，整体流程如下图所示 ​\t训练教师网络的过程很简单，下面详细讲讲第二步：高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到。示意图如上。 L=(1−α)Lhard+αT2Lsoft L = (1 - \\alpha)L_{hard} + \\alpha T^2 L_{soft} L=(1−α)Lhard​+αT2Lsoft​ viv_ivi​ : 教师网络的logits。 ziz_izi​ : 学生网络的logits。 piTp_i^TpiT​ : 教师网络的在温度 TTT 下的softmax输出在第 iii 类上的值。 qiTq_i^TqiT​ : 学生网络的在温度 TTT 下的softmax输出在第 iii 类上的值。 cic_ici​ : 在第 iii 类上的ground truth值,, 正标签取1，负标签取0。 NNN : 总标签数量。 ​\t教师网络和学生网络同时输入 training set (这里可以直接复用训练教师网络用到的training set)，用教师网络产生的softmax distribution (with high temperature) 来作为soft target，学生网络在相同温度 TTT 条件下的softmax输出和soft target的cross entropy就是Loss函数的第一部分 LsoftL_{soft}Lsoft​\n​\tLsoft=−∑jNpjTlog⁡(qjT)L_{soft}=-\\sum_j^Np_j^T\\log(q_j^T)Lsoft​=−∑jN​pjT​log(qjT​) ，其中 piT=exp⁡(vi/T)∑kNexp⁡(vk/T)p_i^T=\\frac{\\exp(v_i/T)}{\\sum_k^N\\exp(v_k/T)}piT​=∑kN​exp(vk​/T)exp(vi​/T)​ , qiT=exp⁡(zi/T)∑kNexp⁡(zk/T)q_i^T=\\frac{\\exp(z_i/T)}{\\sum_k^N\\exp(z_k/T)}qiT​=∑kN​exp(zk​/T)exp(zi​/T)​\n​\t学生网络在 T=1T=1T=1 的条件下的softmax输出和ground truth的cross entropy就是Loss函数的第二部分 LhardL_{hard}Lhard​ 。\n​\tLhard=−∑jNcjlog⁡(qj1)L_{hard}=-\\sum_j^Nc_j\\log(q_j^1)Lhard​=−∑jN​cj​log(qj1​) ，其中 qj1=exp⁡(zi)∑kNexp⁡(zk)q_j^1=\\frac{\\exp(z_i)}{\\sum_k^N\\exp(z_k)}qj1​=∑kN​exp(zk​)exp(zi​)​\n​\t第二部分Loss 的必要性其实很好理解：教师网路也有一定的错误率，使用ground truth可以有效降低错误被传播给学生网络的可能。打个比方，老师虽然学识远远超过学生，但是他仍然有出错的可能，而这时候如果学生在老师的教授之外，可以同时参考到标准答案，就可以有效地降低被老师偶尔的错误“带偏”的可能性。\n讨论: 为什么要用 T2T^2T2 来缩放 ？\n​\t实验发现第二部分所占比重比较小的时候，能产生最好的结果，这是一个经验的结论。一个可能的原因是，由于soft target产生的gradient与hard target产生的gradient之间有与 TTT 相关的比值。\n​\t简而言之：在知识蒸馏中, soft loss（distillation loss）所带来的梯度，其大小相比hard loss会因温度 TTT 而缩小一个 1T2\\frac{1}{T^2}T21​ 的量级。因此，为了让两部分损失的梯度规模在训练中可比，我们把soft loss乘上 T2T^2T2 。\n4.5 主要优点 # 模型压缩：减少参数与计算量，适用于移动端与嵌入式系统。\n性能保持：学生模型精度接近甚至超过教师模型。\n泛化增强：通过学习教师的隐含知识，模型对不确定样本更鲁棒。\n迁移能力强：可与剪枝、量化、神经网络架构搜索等方法结合。\n4.6 知识蒸馏在LLM中的应用 # ​\t在大语言模型（LLM）领域，知识蒸馏主要用于将超大规模模型（如 GPT、LLaMA 等）的语言理解与生成能力迁移到小型模型中，以实现更高的推理效率和端侧部署能力。其蒸馏过程通常包括以下几个步骤：\n教师输出生成：教师模型根据给定的 Prompt 生成高质量文本响应，作为训练数据的参考输出。 教师与学生对齐输入：教师和学生模型使用相同的输入序列（通常为「Prompt + 教师输出」拼接后的完整序列）。 概率分布蒸馏：在每个时间步（token-level），计算教师与学生预测分布之间的差异（例如使用 KL 散度）。 综合损失优化：综合使用软标签损失（Soft Loss）和硬标签损失（Hard Loss）进行训练。 Ltotal=(1−α)Lhard+αT2Lsoft L_{total} = (1 - \\alpha)L_{hard} + \\alpha T^2 L_{soft} Ltotal​=(1−α)Lhard​+αT2Lsoft​​\t其中，LhardL_{hard}Lhard​ 为交叉熵损失，LsoftL_{soft}Lsoft​ 为基于教师概率分布的 KL 散度损失，TTT 为蒸馏温度系数。\n生成长度不一致问题与对齐策略\n​\t由于 LLM 是自回归生成模型，不同输入或模型生成的文本长度往往不同（教师与学生生成的 token 数量可能不一致），若直接计算分布差异会导致 token 无法对齐。为解决这一问题，常用以下两种策略：\n截断对齐（Truncation Alignment） 当教师与学生生成长度不一致时，取二者生成序列的最短长度 T′=min⁡(TT,TS)T\u0026#x27; = \\min(T_T, T_S)T′=min(TT​,TS​)，仅在前 T′T\u0026#x27;T′ 个 token 上计算损失。 这种方法简单有效，但可能忽略教师输出中后续有价值的信息。\n强制对齐（Teacher Forcing Alignment） 在训练时，固定输入序列为「Prompt + 教师输出」，让学生模型强制学习教师在每个位置的预测分布，而不让学生自由生成。 这种方式保证教师与学生在时间步上严格对齐，是当前主流 LLM 蒸馏（如 DistilBERT、TinyLLaMA、Alpaca 等）采用的方案。 其蒸馏损失定义为： Lsoft=1T∑t=1TKL(pT(T)(⋅∣x\u0026lt;t)∥pS(T)(⋅∣x\u0026lt;t)) L_{soft} = \\frac{1}{T} \\sum_{t=1}^{T} \\text{KL}\\big(p_T^{(T)}(\\cdot|x_{\u0026lt;t}) \\| p_S^{(T)}(\\cdot|x_{\u0026lt;t})\\big) Lsoft​=T1​t=1∑T​KL(pT(T)​(⋅∣x\u0026lt;t​)∥pS(T)​(⋅∣x\u0026lt;t​)) 该方法能显著提高训练稳定性，并避免长度差异带来的对齐误差。\n参考：\n论文精读：LoRa: Low-Rank Adaptation of Large Language Models\n微调方法lora论文逐段精读\n大模型PEFT的LORA算法 lora_rank， lora_alpha\n模型量化：量化基础 对称量化 非对称量化 极大值量化 零点量化\nQLoRA、GPTQ：模型量化概述\n论文精读：QLoRA: Efficient Finetuning of Quantized LLMs\n【论文精读】QLORA: Efficient Finetuning of Quantized LLMs\n史上最全模型蒸馏全解：一步一步带你走向模型蒸馏全流程\n【精读AI论文】知识蒸馏\n【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 ","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/lora-quantization-distillation/","section":"文章","summary":"","title":"LoRA、模型量化与模型蒸馏","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/peft/","section":"标签","summary":"","title":"PEFT","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/qlora/","section":"标签","summary":"","title":"QLoRA","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E6%A8%A1%E5%9E%8B%E9%87%8F%E5%8C%96/","section":"标签","summary":"","title":"模型量化","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E6%A8%A1%E5%9E%8B%E8%92%B8%E9%A6%8F/","section":"标签","summary":"","title":"模型蒸馏","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/dpo/","section":"标签","summary":"","title":"DPO","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/grpo/","section":"标签","summary":"","title":"GRPO","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/ppo/","section":"标签","summary":"","title":"PPO","type":"tags"},{"content":" 1. PPO(近端策略优化) # ​\tPPO 是 OpenAI 在 2017 年提出的一种强化学习算法，其设计的初衷是为了解决传统策略梯度算法（Policy Gradient）中训练不稳定、更新步长难以确定的问题。在 RLHF 的背景下，PPO 的稳定性和可靠性使其成为优化语言模型的首选。\n1.1 PPO的核心直觉 # ​\t想象一下你在教一个孩子玩一个游戏。如果孩子每次尝试后，你都给予他非常剧烈的反馈（要么是极高的赞扬，要么是严厉的批评），他可能会感到困惑和不知所措，甚至放弃学习。一个更好的方法是，在他当前行为的基础上，温和地引导他向更好的方向改进，每次只做小小的调整。\n​\tPPO 的核心思想与此类似。它认为，在更新模型的策略（即模型生成文本的方式）时，更新的步伐不应该太大。如果新策略与旧策略相差过大，可能会导致模型性能的急剧下降，即“掉下悬崖”。PPO 通过一个巧妙的机制，将策略更新限制在一个“近端”的、可信赖的区域内，从而保证了学习过程的稳定。\n1.2 PPO的关键机制 # ​\tPPO 的魔法在于其目标函数的设计。标准的策略梯度算法的目标是最大化期望回报。PPO 在此基础上引入了两个关键概念：概率比（Probability Ratio）和优势函数（Advantage Function）。\n**优势函数A(s,a)A(s, a)A(s,a) **：代表在状态 sss 下，采取动作 aaa 相较于平均水平有多好。在 RLHF 中，它通常由奖励模型（RM）的输出减去一个基线（Baseline，通常是另一个叫做“价值模型”的网络的输出）来估计。如果 A\u0026gt;0A\u0026gt;0A\u0026gt;0 ，说明这个回答比预期的要好，我们应该增加生成它的概率；反之则减少。 **概率比 rθr_{\\theta}rθ​ **：表示新策略 πθ\\pi_{\\theta}πθ​ 和旧策略 πθold\\pi_{\\theta_{old}}πθold​​ 对于同一个动作的输出概率之比。rθ=πθπθoldr_{\\theta}=\\frac{\\pi_{\\theta}}{\\pi_{\\theta_{old}}}rθ​=πθold​​πθ​​。如果， rθ\u0026gt;1r_{\\theta}\u0026gt;1rθ​\u0026gt;1 说明新策略更倾向于采取这个动作。 PPO 的目标函数（简化版）如下： LCLIP(θ)=Et[min⁡(rt(θ)A^t, clip(rt(θ),1−ϵ,1+ϵ)A^t)] L^{CLIP}(\\theta) = \\mathbb{E}_{t}\\left[ \\min\\Big( r_t(\\theta)\\hat{A}_t, \\; \\text{clip}(r_t(\\theta), 1-\\epsilon, 1+\\epsilon)\\hat{A}_t \\Big) \\right] LCLIP(θ)=Et​[min(rt​(θ)A^t​,clip(rt​(θ),1−ϵ,1+ϵ)A^t​)] 让我们来解读这个公式：\n我们希望最大化概率比 rθr_{\\theta}rθ​ 和优势函数 A^t\\hat{A}_tA^t​ 的乘积。当优势 A^t\\hat{A}_tA^t​ 为正时（好回答），我们希望增大 rθr_{\\theta}rθ​ ；当优势为负时（坏回答），我们希望减小 rθr_{\\theta}rθ​ 。这很直观。 关键在于 clip 函数。clip(x, min, max) 会将 x 限制在 [min, max] 区间内。这里的 ϵ\\epsilonϵ 是一个超参数（通常取 0.1 或 0.2），它定义了一个“信任区域”。 min 函数的作用： 当优势 A^t\u0026gt;0\\hat{A}_t\u0026gt;0A^t​\u0026gt;0 时：我们希望增大 rθr_{\\theta}rθ​ ，但 clip 函数将其上界限制在 。这意味着，即使某个回答非常好，我们对策略的更新也不会过于激进，防止“一步走得太大扯着蛋”。 当优势 A^t\u0026lt;0\\hat{A}_t\u0026lt;0A^t​\u0026lt;0 时：我们希望减小 rθr_{\\theta}rθ​ ，clip 函数将其下界限制在 。这意味着，我们也不会因为一个糟糕的回答而过度惩罚模型，给了模型“改过自新”的机会。 ​\t通过这种方式，PPO 像一个温和而有耐心的老师，确保模型在每次学习后都能稳定进步，而不会因为某次剧烈的更新而崩溃。\n1.3 PPO在RLHF中的角色 # 在 RLHF 流程中，PPO 的工作流程如下：\n初始化：用 SFT 模型的权重初始化策略模型（Policy Model），并通常也用它来初始化价值模型（Value Model）。 采样：从一个指令数据集中随机抽取一个指令（Prompt）。 生成：策略模型根据指令生成一个回答。 评估：奖励模型（RM）对“指令-回答”对打分，得到奖励（Reward）。价值模型（Value Model）对指令进行评估，得到价值（Value）。 计算优势：根据奖励和价值计算优势函数。 更新：使用 PPO 的 Clipped Surrogate Objective 计算损失，并更新策略模型和价值模型的参数。 循环：重复步骤 2-6，直到模型收敛。 ​\t同时，为了防止模型在优化过程中“忘记”SFT 阶段学到的知识，或者为了防止模型生成一些虽然奖励高但内容乱七八糟的文本，PPO 的损失函数中通常还会加入一个 KL 散度惩罚项。这个惩罚项用来衡量当前策略与初始 SFT 策略的差异，差异越大，惩罚越重，确保模型在追求高奖励的同时，不会偏离其原始的语言能力。\n1.4 PPO算法工作流程 # 损失函数推导过程见：零基础学习强化学习算法：PPO\n1.4.1 阶段 1：收集数据 # 从旧策略 πold\\pi_\\text{old}πold​（固定参数）采样生成样本\n给定一个 prompt（输入），用当前策略（语言模型）生成一段文本（token 序列）。 这一过程就形成了一条 trajectory： τ={(s0,a0,r0),(s1,a1,r1),…,(sT,aT,rT)} \\tau = \\{(s_0, a_0, r_0), (s_1, a_1, r_1), \\dots, (s_T, a_T, r_T)\\} τ={(s0​,a0​,r0​),(s1​,a1​,r1​),…,(sT​,aT​,rT​)}其中：\nsts_tst​：状态（当前上下文） ata_tat​：当前生成的 token rtr_trt​：该 token 或整个序列的奖励（来自奖励模型） 记录以下数据：\n每个 token 的 log 概率：log⁡πold(at∣st)\\log π_{\\text{old}}(a_t|s_t)logπold​(at​∣st​) 价值网络预测：Vold(st)V_{\\text{old}}(s_t)Vold​(st​) 奖励：rtr_trt​ 得到一批样本（batch of trajectories）\n1.4.2 计算优势函数和目标回报 # 计算时序差分 δₜ δt=rt+γVold(st+1)−Vold(st) \\delta_t = r_t + \\gamma V_{\\text{old}}(s_{t+1}) - V_{\\text{old}}(s_t) δt​=rt​+γVold​(st+1​)−Vold​(st​) 使用广义优势估计（GAE）计算优势 Aₜ At=∑l=0∞(γλ)lδt+l A_t = \\sum_{l=0}^{\\infty} (\\gamma \\lambda)^l \\delta_{t+l} At​=l=0∑∞​(γλ)lδt+l​ GAE 的作用：平衡偏差与方差，获得平滑稳定的优势估计。\n计算价值网络的目标值（标签） V^ttarget=At+Vold(st) \\hat{V}_t^{target} = A_t + V_{\\text{old}}(s_t) V^ttarget​=At​+Vold​(st​) 这是我们训练 value network 的监督信号。\n对优势进行标准化（稳定训练） At←At−mean(At)std(At)+ϵ A_t \\leftarrow \\frac{A_t - \\text{mean}(A_t)}{\\text{std}(A_t) + \\epsilon} At​←std(At​)+ϵAt​−mean(At​)​ 1.4.3 策略与价值网络的联合训练 # 我们通常会重复多个 epoch，每次从采样数据中取多个 minibatch 来更新。\n策略网络更新\n重新计算当前策略的 log 概率：log⁡πθ(at∣st)\\log π_{\\theta}(a_t|s_t)logπθ​(at​∣st​)\n计算概率比 rt(θ)=πθ(at∣st)πold(at∣st)=exp⁡(log⁡πθ−log⁡πold) r_t(\\theta) = \\frac{π_\\theta(a_t|s_t)}{π_{\\text{old}}(a_t|s_t)} = \\exp(\\log π_{\\theta} - \\log π_{\\text{old}}) rt​(θ)=πold​(at​∣st​)πθ​(at​∣st​)​=exp(logπθ​−logπold​) 计算剪切（Clipped）目标函数 Ltpolicy(θ)=min⁡[rt(θ)At,clip(rt(θ),1−ϵ,1+ϵ)At] L^{policy}_t(\\theta) = \\min \\left[ r_t(\\theta) A_t, \\text{clip}(r_t(\\theta), 1 - \\epsilon, 1 + \\epsilon) A_t \\right] Ltpolicy​(θ)=min[rt​(θ)At​,clip(rt​(θ),1−ϵ,1+ϵ)At​] 当策略改变过大（ratio 超出 [1−ε,1+ε]）时，梯度被截断； 保证更新“近端”稳定。 取 batch 的平均作为策略损失 Lpolicy=−Et[Ltpolicy(θ)] L^{policy} = -\\mathbb{E}_t[L^{policy}_t(\\theta)] Lpolicy=−Et​[Ltpolicy​(θ)] （负号是因为我们希望最大化优势）\n价值网络更新\n使用 MSE 损失\nLvalue=12Et[(Vθ(st)−V^ttarget)2] L^{value} = \\frac{1}{2} \\mathbb{E}_t \\left[ (V_\\theta(s_t) - \\hat{V}_t^{target})^2 \\right] Lvalue=21​Et​[(Vθ​(st​)−V^ttarget​)2]Vθ(st)V_\\theta(s_t)Vθ​(st​)：当前价值网络输出\nV^ttarget\\hat{V}_t^{target}V^ttarget​：由 GAE 得到的回报估计\n熵正则（Entropy）项（可选） Lentropy=−Et[H(πθ(⋅∣st))] L^{entropy} = -\\mathbb{E}_t [ H(π_\\theta(\\cdot|s_t)) ] Lentropy=−Et​[H(πθ​(⋅∣st​))] 鼓励策略保持一定的随机性，避免过早收敛。\n总损失函数 Ltotal=Lpolicy+c1Lvalue−c2Lentropy L^{total} = L^{policy} + c_1 L^{value} - c_2 L^{entropy} Ltotal=Lpolicy+c1​Lvalue−c2​Lentropy 然后对该损失反向传播，更新网络参数 θ。\n1.4.4 迭代更新 # 重复执行： 用新策略采样（rollout）； 计算 A、V、R； 多 epoch、小 batch 训练； 更新策略和价值网络； 直到收敛或性能达到要求。 1.5 PPO的挑战 # ​\t尽管 PPO 非常成功，但 RLHF 中的 PPO 流程相当复杂。它需要同时维护和训练多个模型（策略模型、价值模型、奖励模型、SFT 参考模型），这使得训练过程非常消耗计算资源和内存，且超参数调整也颇具挑战。正是这些挑战，催生了更简洁的替代方案——DPO。\n2. DPO(直接偏好优化) # ​\tDPO 由斯坦福大学的研究者于 2023 年提出，它以一种惊人的简洁性，对传统的 RLHF 流程发起了挑战。DPO 的核心洞见是：我们完全可以绕过奖励模型建模这一中间步骤，直接利用人类的偏好数据来优化语言模型。\n2.1 DPO的核心直觉 # ​\t传统 RLHF 是一个“两步走”的过程：先用偏好数据（A 比 B 好）训练一个能给绝对分数（A 得 90 分，B 得 60 分）的奖励模型，然后再用这个分数去指导强化学习。\n​\tDPO 的提出者反思道：我们最终的目标不就是让模型知道“A 比 B 好”吗？为什么非要先把它变成“A=90分，B=60分”，再回头去学习这个偏好呢？这个中间的奖励建模步骤不仅复杂，还可能引入误差。我们能不能直接建立一个从“偏好”到“策略更新”的数学桥梁？\n​\tDPO 做到了这一点。它巧妙地将偏好数据和语言模型的策略联系起来，推导出了一个简单的、可以用分类损失函数直接优化的目标。\n2.2 从偏好到损失 # ​\tDPO 的理论推导略显复杂，但其最终的损失函数却异常优雅。它始于一个名为 Bradley-Terry 的模型，该模型常用于根据成对比较来估计事物的排名。DPO 假设人类的偏好概率 p∗p^*p∗ 可以用一个潜在的奖励模型 r∗(y,x)r^*(y, x)r∗(y,x) 来建模： P(yw≻yl∣x)=11+exp((r(yl∣x)−r(yw∣x)))=σ(r(yw∣x)−r(yl∣x)) P(y_w \\succ y_l \\mid x) = \\frac{1}{1 + exp((r(y_l|x) - r(y_w|x)))}=\\sigma\\big(r(y_w|x) - r(y_l|x)\\big) P(yw​≻yl​∣x)=1+exp((r(yl​∣x)−r(yw​∣x)))1​=σ(r(yw​∣x)−r(yl​∣x)) ​\t这里 ywy_wyw​ 是被偏好的回答（winner），yly_lyl​ 是不被偏好的回答（loser），xxx 是指令。这个公式本质上是一个 Sigmoid 函数，表示 ywy_wyw​ 的奖励比 yly_lyl​ 高得越多，人类偏好 ywy_wyw​ 的概率就越大。\n​\t接下来是 DPO 最关键的一步。它通过一系列精妙的数学变换，证明了优化语言模型 πθ\\pi_{\\theta}πθ​ 以最大化这个奖励，等价于最小化以下损失函数： LDPO(θ)=−E(x,yw,yl)[log⁡σ(βlog⁡πθ(yw∣x)πθ(yw∣x)−βlog⁡πref(yl∣x)πref(yl∣x))] L_{\\text{DPO}}(\\theta) = -\\mathbb{E}_{(x,y_w,y_l)} \\Big[ \\log \\sigma\\big( \\beta \\log \\frac{\\pi_\\theta(y_w|x)} {\\pi_\\theta(y_w|x)} - \\beta \\log \\frac{\\pi_{\\text{ref}}(y_l|x)} {\\pi_{\\text{ref}}(y_l|x)} \\big) \\Big] LDPO​(θ)=−E(x,yw​,yl​)​[logσ(βlogπθ​(yw​∣x)πθ​(yw​∣x)​−βlogπref​(yl​∣x)πref​(yl​∣x)​)] 让我们来解读这个“天书般”的公式：\nE(x,yw,yl)\\mathbb{E}_{(x,y_w,y_l)}E(x,yw​,yl​)​ 是我们的偏好数据集，包含了大量的 (x,yw,yl)(x, y_w, y_l)(x,yw​,yl​) 三元组。 πθ\\pi_\\thetaπθ​ 是我们正在训练的模型。 πref\\pi_{ref}πref​ 是一个参考模型，通常就是 SFT 阶段得到的模型。它的作用和 PPO 中的 KL 散度惩罚项类似，是为了防止模型“跑偏”。 πθ(y∣x)πref(y∣x)\\frac{\\pi_\\theta(y|x)} {\\pi_{\\text{ref}}(y|x)}πref​(y∣x)πθ​(y∣x)​ 衡量了当前模型相对于参考模型，生成回答 yyy 的概率增加了多少。我们可以把 log⁡πθ(y∣x)πref(y∣x)\\log \\frac{\\pi_\\theta(y|x)} {\\pi_{\\text{ref}}(y|x)}logπref​(y∣x)πθ​(y∣x)​ 看作是模型对回答 yyy 的“隐式奖励”。 核心部分是括号里的差值： (隐式奖励w−隐式奖励l)(\\text{隐式奖励}_w-\\text{隐式奖励}_l)(隐式奖励w​−隐式奖励l​) 。DPO 的目标就是最大化这个差值。也就是说，它希望当前模型生成“获胜”回答的概率相对于参考模型的增幅，要远大于生成“失败”回答的概率的增幅。 最外层的 log⁡σ\\log \\sigmalogσ (log-sigmoid) 结构，使得这个目标变成了一个标准的二元交叉熵损失。这正是 DPO 的绝妙之处，它将复杂的强化学习问题，转化为了一个我们非常熟悉的分类问题。 2.3 推导过程 # 符号约定： πθ(y∣x)\\pi_\\theta(y|x)πθ​(y∣x)：待优化策略（语言模型） πref(y∣x)\\pi_{\\text{ref}}(y|x)πref​(y∣x)：参考策略（通常是 SFT 模型） r(y∣x)r(y|x)r(y∣x)：奖励函数（reward） β\u0026gt;0\\beta\u0026gt;0β\u0026gt;0：温度/平衡参数 σ(s)=11+e−s\\sigma(s) = \\frac{1}{1+e^{-s}}σ(s)=1+e−s1​：Sigmoid 函数 推导过程见：DPO算法讲解\n2.4 DPO的工作流程 # 相比 PPO，DPO 的流程大大简化：\n2.4.1 准备数据 # 收集人类偏好数据集 D={(x,yw,yl)}D = \\{(x, y_w, y_l)\\}D={(x,yw​,yl​)} 每个样本包含一个输入 prompt xxx 两个模型输出：一个被偏好 (winner)，一个被拒绝 (loser) 训练一个SFT（Supervised Fine-Tuned）参考模型 记作 πref(y∣x)\\pi_{ref}(y|x)πref​(y∣x) 它在高质量文本上训练，用作“原始分布约束”。 2.4.2 计算损失 # 对于每个样本 (x,yw,yl)(x, y_w, y_l)(x,yw​,yl​)：\n计算当前模型的 log 概率：log⁡πθ(yw∣x),log⁡πθ(yl∣x)\\log \\pi_\\theta(y_w|x), \\quad \\log \\pi_\\theta(y_l|x)logπθ​(yw​∣x),logπθ​(yl​∣x)\n计算参考模型的 log 概率：log⁡πref(yw∣x),log⁡πref(yl∣x)\\log \\pi_{ref}(y_w|x), \\quad \\log \\pi_{ref}(y_l|x)logπref​(yw​∣x),logπref​(yl​∣x)\n构造相对对比项： Δ=β[(log⁡πθ(yw∣x)−log⁡πref(yw∣x))−(log⁡πθ(yl∣x)−log⁡πref(yl∣x))] \\Delta = \\beta \\left[ (\\log \\pi_\\theta(y_w|x) - \\log \\pi_{ref}(y_w|x)) - (\\log \\pi_\\theta(y_l|x) - \\log \\pi_{ref}(y_l|x)) \\right] Δ=β[(logπθ​(yw​∣x)−logπref​(yw​∣x))−(logπθ​(yl​∣x)−logπref​(yl​∣x))] 计算 DPO 损失：LDPO=−log⁡σ(Δ)L_{DPO} = -\\log \\sigma(\\Delta)LDPO​=−logσ(Δ)\n2.4.3 优化 # 对整个 batch 求平均损失：Lbatch=1N∑iLDPO(i)L_{batch} = \\frac{1}{N} \\sum_i L_{DPO}^{(i)}Lbatch​=N1​∑i​LDPO(i)​ 反向传播更新策略模型参数 θ。 参考模型 πrefπ_refπr​ef 固定不变（不参与更新）。 重复多个 epoch 直到收敛。 ​\t整个过程不需要训练一个独立的奖励模型，也不需要复杂的采样和优势计算，更没有价值模型。这使得 DPO 在实现上更简单，训练更稳定，也更节省计算资源。\n2.5 DPO的权衡 # ​\tDPO 虽然简洁高效，但也有其适用场景和潜在的局限。DPO 强依赖于高质量的成对偏好数据。如果偏好数据的质量不高，或者标注不一致，DPO 的效果可能会受到影响。此外，由于其直接优化的特性，它对于数据分布的变化可能比 PPO 更敏感。在某些需要精细控制奖励函数的复杂场景下，PPO 的灵活性可能依然具有优势。\n3. GRPO(组别相对策略优化) # ​\t就在 PPO 和 DPO 的讨论如火如荼之时，DeepSeek-AI 在其模型（如 DeepSeekMath 和 DeepSeek-R1）的训练中，提出并使用了一种名为 GRPO 的新方法，为 RLHF 带来了新的视角。GRPO 可以看作是 PPO 的一个变种，它通过一种新颖的方式来估计优势函数，从而省去了 PPO 中昂贵的价值模型（Critic Model）。\n3.1 GRPO核心直觉 # ​\tPPO 的核心是优势函数 A(s,a)=R(s,a)−V(s)A(s,a)=R(s,a)-V(s)A(s,a)=R(s,a)−V(s) ，它需要一个奖励模型 R(s,a)R(s,a)R(s,a) 来提供奖励，还需要一个价值模型 V(s)V(s)V(s) 来提供基线（Baseline），即在状态 sss 下的平均期望回报。训练和维护这个价值模型是 PPO 流程中主要的复杂性和成本来源之一。\n​\tGRPO 的提出者思考：我们能不能找到一种更简单的方式来估计这个“平均水平”呢？\n​\t他们的答案是：利用群体智慧。对于同一个指令，我们不只生成一个回答，而是生成一组（Group）回答。然后，我们假设这组回答的平均奖励，就可以近似地作为当前策略下的“平均水平”，也就是价值 的一个估计。\n​\t这个想法非常直观。想象一下，要评价一个学生这次考试的成绩（某个回答的奖励）是好是坏（优势），我们不需要知道他历史上的平均分（价值模型的输出），我们可以直接看他这次在班级里（一组回答中）的排名。如果他的分数远高于班级平均分，那么他的优势就是正的，反之亦然。\n3.2 GRPO关键机制：组内优势估计 # GRPO 的核心是对 PPO 中优势函数的计算方式进行了修改。其步骤如下：\n组采样 (Group Sampling)：对于一个给定的指令 qqq ，使用当前的策略模型 πθ\\pi_\\thetaπθ​ 生成一个包含 GGG 个回答的组 o1,o2,…,oGo_1,o_2,\\dots,o_Go1​,o2​,…,oG​ 。\n组评估 (Group Evaluation)：使用一个奖励函数（可以是一个训练好的奖励模型，也可以是某种可计算的启发式规则，例如代码的执行结果、数学题的答案是否正确等）为组内的每一个回答 oio_ioi​ 打分，得到奖励 rir_iri​ 。\n组内优势计算 (Group-Relative Advantage Estimation)*：计算组内所有回答的平均奖励 rˉ=1G∑i=1Gri\\bar{r}=\\frac1G \\sum_{i=1}^{G} r_irˉ=G1​∑i=1G​ri​ 和标准差 σr\\sigma_rσr​ 。对于组内的每一个回答 oio_ioi​ ，其优势被定义为其归一化后的奖励： Ai=ri−rˉσr A_i=\\frac{r_i-\\bar{r}}{\\sigma_r} Ai​=σr​ri​−rˉ​ 这种方法被称为组内奖励归一化（Group-wise Reward Normalization）。它直接用组内的统计量（均值和标准差）来替代了 PPO 中需要专门训练的价值模型所扮演的角色。\n策略更新：一旦计算出了每个样本的优势 AiA_iAi​，接下来的步骤就和 PPO 非常相似了。GRPO 同样使用 Clipped Surrogate Objective 来更新策略模型。\nLGRPO=E^q,oi[∑i=1Gmin(ri(θ)Ai,clip(ri(θ),1−ϵ,1+ϵ)Ai)] L^{\\text{GRPO}}=\\hat{\\mathbb{E}}_{q,o_i} \\Big [ \\sum_{i=1}^G \\text{min} \\big(r_i(\\theta)A_i, \\text{clip}(r_i(\\theta),1-\\epsilon,1+\\epsilon)A_i \\big) \\Big ] LGRPO=E^q,oi​​[i=1∑G​min(ri​(θ)Ai​,clip(ri​(θ),1−ϵ,1+ϵ)Ai​)]​\t其中 rt(θ)=πθ(oi∣q)πold(oi∣q)r_t(\\theta) = \\frac{π_\\theta(o_i|q)}{π_{\\text{old}}(o_i|q)}rt​(θ)=πold​(oi​∣q)πθ​(oi​∣q)​ 是概率比，AiA_iAi​ 是上面计算出的组内相对优势。\n3.3 GRPO工作流程 # 3.3.1 采样阶段：生成多样回答 # 对于同一个输入 prompt xxx，模型采样多条回答：{y1,y2,...,yN}\\{y_1, y_2, ..., y_N\\}{y1​,y2​,...,yN​}（例如生成 N=4 条不同回答）\n3.3.2 奖励评估阶段：整体打分 # 使用奖励模型或人类偏好模型对每个回答整体打分：ri=R(x,yi)r_i = R(x, y_i)ri​=R(x,yi​)\n这时每条回答只对应一个整体奖励。\n3.3.3 群体标准化（计算相对优势） # 我们并不再去估计状态价值 VtV_tVt​，而是直接使用群体内部的相对得分作为优势： Ai=ri−mean(r1,...,rN)std(r1,...,rN)+ϵ A_i = \\frac{r_i - \\text{mean}(r_1, ..., r_N)}{\\text{std}(r_1, ..., r_N) + \\epsilon} Ai​=std(r1​,...,rN​)+ϵri​−mean(r1​,...,rN​)​ 即：\n如果回答 yiy_iyi​ 比平均好 → 优势 Ai\u0026gt;0A_i \u0026gt; 0Ai​\u0026gt;0； 如果比平均差 → 优势 Ai\u0026lt;0A_i \u0026lt; 0Ai​\u0026lt;0。 这样我们直接得到每条回答的“相对优势”。 这一过程本质上用群体平均值代替了价值函数 V(s)V(s)V(s)。\n3.3.4 广播优势到 token 层 # 虽然每个回答 yiy_iyi​ 只有一个优势 AiA_iAi​，但我们仍要更新生成该回答的每个 token的策略参数。\n因此：At(i)=Ai,∀t∈tokens of yiA_{t}^{(i)} = A_i, \\quad \\forall t \\in \\text{tokens of } y_iAt(i)​=Ai​,∀t∈tokens of yi​\n也就是：\n整个回答的优势平均分配到每个 token 上。\n3.3.5 策略更新 # 使用和 PPO 相同的 clipped policy loss： LGRPOpolicy=−E[min⁡(rt(θ)At,clip(rt(θ),1−ϵ,1+ϵ)At)] L_{GRPO}^{policy} = -\\mathbb{E}\\left[ \\min\\left( r_t(\\theta) A_t, \\text{clip}(r_t(\\theta), 1 - \\epsilon, 1 + \\epsilon) A_t \\right) \\right] LGRPOpolicy​=−E[min(rt​(θ)At​,clip(rt​(θ),1−ϵ,1+ϵ)At​)] 其中： rt(θ)=πθ(at∣st)πold(at∣st) r_t(\\theta) = \\frac{\\pi_\\theta(a_t|s_t)}{\\pi_{\\text{old}}(a_t|s_t)} rt​(θ)=πold​(at​∣st​)πθ​(at​∣st​)​ 区别仅在于：\nPPO 中的 AtA_tAt​ 来源于 GAE； GRPO 中的 AtA_tAt​来源于群体相对归一化得分。 3.4 GRPO优势与特点 # 高效性：GRPO 最显著的优势是无需价值模型。价值模型通常和策略模型一样大，去掉它可以节省近一半的训练内存和计算量，这对于训练超大规模模型来说意义重大。 灵活性：GRPO 对奖励函数的定义非常灵活。它不一定需要一个端到端训练的神经网络奖励模型。在某些任务中（如代码生成、数学推理），我们可以设计出可验证的奖励函数（Verifiable Reward Functions）。例如，如果生成的代码能成功运行并通过所有单元测试，就给予高奖励；如果数学题的最终答案正确，也给予高奖励。这种方式使得奖励信号更客观、更廉价。 稳定性：通过组内归一化，GRPO 使得优势函数的尺度保持在一个稳定的范围内，这有助于稳定训练过程，减少了对超参数的敏感性。 3.5 GRPO的适用场景 # GRPO 特别适用于以下场景：\n计算资源受限：当你希望以更低的成本进行 RLHF 训练时，GRPO 是一个极具吸引力的选择。 存在客观评价标准：在代码、数学、科学等领域，可以通过程序化、确定性的方式来评估生成内容的质量，GRPO 可以充分利用这种廉价而准确的奖励信号。 需要提升模型推理能力：DeepSeek 的实践表明，通过精心设计的奖励函数和 GRPO 训练，可以显著提升模型在复杂推理任务上的表现。 4. PPO、DPO 与 GRPO 的全方位对比 # 特性维度 PPO (Proximal Policy Optimization) DPO (Direct Preference Optimization) GRPO (Group Relative Policy Optimization) 核心思想 在可信区域内小步更新策略，以最大化奖励模型给出的分数。 直接将成对的偏好数据转化为一个分类损失，绕过奖励建模。 通过组内回答的相对好坏来估计优势，从而指导策略更新。 方法论 On-policy 强化学习 Off-policy 偏好学习 (类似于分类) On-policy 强化学习 所需模型 4个：策略模型、价值模型 (Critic)、奖励模型 (RM)、SFT 参考模型。 2个：策略模型、SFT 参考模型。 3个 (通常)：策略模型、奖励函数/模型、SFT 参考模型。无价值模型。 数据需求 需要奖励模型能给出的绝对分数。 需要成对的偏好数据 。 需要奖励函数/模型能给出的分数 (可以是相对的)。 计算成本 最高。需要维护和训练多个大型模型，且有复杂的采样循环。 最低。流程类似监督微调，非常简洁高效。 中等。比 PPO 成本低（省去了价值模型），但比 DPO 复杂（仍有 RL 循环）。 主要优势 灵活、鲁棒。适用于各种复杂的奖励函数，是久经考验的工业标准。 简洁、高效。训练稳定，实现简单，大大降低了 RLHF 的门槛。 高效、灵活。显著降低了 PPO 的成本，且能灵活利用可验证奖励。 主要挑战 复杂、昂贵。实现和调试难度大，对资源消耗极高。 依赖数据质量。对偏好数据的一致性和质量要求高。 组采样开销。需要为每个指令生成多个样本，会增加推理开销。 适用场景 需要精细控制奖励函数、追求极限性能的通用场景。 数据集是成对偏好形式，希望快速、低成本地进行模型对齐的场景。 计算资源受限，或任务存在客观、可程序化验证的奖励标准的场景 (如代码、数学)。 一个生动的比喻 # PPO 像是一位经验丰富的全科医生。他会给你做全面的检查（奖励模型评估），参考你的历史病历（价值模型），然后非常谨慎地给你开药方（Clipped Update），确保疗效的同时将副作用降到最低。这个过程非常完备，但也最耗时耗力。 DPO 像是一位专攻“比对诊断”的专家。你不需要告诉他你的具体指标，你只需要告诉他“相比于昨天，我今天感觉更好了”（偏好数据）。他就能直接根据这些“比对”信息，调整你的治疗方案。这个过程非常直接，省去了很多中间化验环节。 GRPO 像是一位组织“专家会诊”的医生。他把你的情况（指令）告诉一群实习医生（生成一组回答），让他们各自给出诊断方案和信心度（奖励）。然后他根据这些方案在“会诊”中的相对好坏（组内归一化），来决定最终采纳哪个方向的治疗思路。他自己不需要对你的历史情况了如指掌（无需价值模型），而是依赖“集体智慧”做决策。 5.总结 # ​\t从 PPO 的稳定可靠，到 DPO 的简洁直接，再到 GRPO 的高效灵活，我们看到了大模型对齐技术在“效果”、“效率”和“成本”这个不可能三角中的不断探索与演进。\nPPO 作为 RLHF 的开创性和基准性方法，其地位在短期内难以被完全撼动。它强大的灵活性和鲁棒性使其在许多前沿研究和工业应用中依然是首选。 DPO 则成功地为 RLHF “祛魅”，它证明了在许多场景下，我们并不需要复杂的强化学习框架，一个巧妙的损失函数设计就能达到甚至超越 PPO 的效果。它的出现极大地推动了 RLHF 技术的普及和民主化。 GRPO 则在 PPO 的框架内进行了精妙的“减负”，它在保持 PPO 核心优势的同时，显著降低了训练成本，并为利用非传统奖励信号（如可验证奖励）开辟了新的道路，尤其在逻辑推理等领域展现出巨大潜力。 未来何去何从？\n​\t这三种算法并非简单的替代关系，而更可能是一种共存与融合的关系。未来的研究可能会探索：\n混合方法：是否可以将 DPO 的直接偏好学习思想与 PPO/GRPO 的探索性优势结合起来？ 自适应算法：模型是否可以根据任务的特性，自动选择或切换最合适的对齐策略？ 超越偏好对：除了简单的“A\u0026gt;B”偏好，我们如何利用更丰富的反馈信号，例如用户的修改意见、多维度评分等？ 对齐的理论边界：这些算法在多大程度上能真正理解和内化人类的价值观，而不是仅仅在表面上拟合偏好数据？ ​\t毫无疑问，PPO、DPO 和 GRPO 的发展，只是人类探索如何与超强人工智能和谐共舞的序章。理解它们的原理，掌握它们的应用，不仅是每一位 AI 从业者的必备技能，也是我们洞察未来技术走向的一扇重要窗口。\n参考：\n从 PPO、DPO 到 GRPO：万字长文详解大模型训练中的三大关键算法\n零基础学习强化学习算法：PPO\nDPO算法讲解\nDPO\u0026mdash;直接偏好优化（DPO）：你的语言模型实际上是一个奖励模型\nDeepSeek-GRPO\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/ppo-dpo-grpo/","section":"文章","summary":"","title":"PPO、DPO 与 GRPO","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E5%BC%BA%E5%8C%96%E5%AD%A6%E4%B9%A0/","section":"标签","summary":"","title":"强化学习","type":"tags"},{"content":" 1. 强化学习 # 1.1 强化学习基本概念 # ​\t强化学习（Reinforcement Learning, RL）是一种让智能体（Agent）通过与环境（Environment）交互来**学习最优行为策略（Policy）**的机器学习方法。 其核心思想是：\n智能体根据当前状态选择一个动作，通过环境反馈的奖励（Reward）来调整自己的策略，以最大化长期回报（Cumulative Reward）。\n与监督学习和无监督学习的对比：\n学习方式 数据特征 学习目标 是否有标签 监督学习 已知输入与正确输出 学习从输入到输出的映射关系 ✅ 有标签 无监督学习 只有输入数据 发现数据的隐藏结构或聚类 ❌ 无标签 强化学习 与环境交互获得奖励 学习能获得最大长期回报的策略 ⚙️ 奖励信号取代标签 强化学习既不是监督学习，也不是无监督学习。 它通过与环境的交互、自主探索获得经验，从而优化策略，而不是依赖固定标签的数据。\n1.2 为什么使用强化学习 # 强化学习适用于以下场景：\n缺乏明确监督信号：例如机器人控制、游戏对战、自动驾驶等，无法为每一步提供准确标签； 决策过程依赖序列：当前的动作会影响未来的状态与奖励； 目标是最大化长期回报，而非即时准确率； 可通过环境交互不断改进策略，具有自适应与探索特性。 RL广泛应用于：\nAlphaGo、Atari 游戏智能体； 机器人控制； 推荐系统； RLHF（人类反馈强化学习）中的策略微调。 1.3 强化学习基本组成要素 # 强化学习系统通常由以下五个核心要素组成：\n要素 含义 Agent（智能体） 执行动作、学习策略的主体 Environment（环境） 与智能体交互并提供反馈（奖励和新状态） State（状态） 描述当前环境的情况 Action（动作） 智能体在某一状态下可以执行的行为 Reward（奖励） 环境在智能体执行动作后返回的数值反馈，用于衡量该动作的好坏 智能体的目标是通过不断试错学习到最优策略（Policy），以在长期内获得尽可能高的奖励。\n交互过程：\n智能体在状态 sts_tst​ 下选择一个动作 ata_tat​ ； 环境返回新的状态 st+1s_{t+1}st+1​ 和奖励 rtr_trt​； 智能体根据奖励调整策略； 循环往复，形成经验序列：(s0,a0,r0,s1,a1,r1,… )(s_0, a_0, r_0, s_1, a_1, r_1, \\dots)(s0​,a0​,r0​,s1​,a1​,r1​,…)，通常把这个序列称为Trajectory 1.4 强化学习的目标 # 强化学习的核心目标是：\n学习一个最优策略 π∗(a∣s)\\pi^*(a|s)π∗(a∣s)，使得 期望累积奖励 最大化。形式化地，强化学习的目标可以表示为： \u0026gt;J(π)=Eπ∗(a∣s)[∑t=0∞γtrt]\u0026gt; \u0026gt; J(\\pi) = \\mathbb{E}_{\\pi^*(a|s)}\\left[ \\sum_{t=0}^{\\infty} \\gamma^t r_t \\right] \u0026gt; \u0026gt;J(π)=Eπ∗(a∣s)​[t=0∑∞​γtrt​]\u0026gt; 其中：\nπ(a∣s)\\pi(a|s)π(a∣s)：策略（Policy），即在状态下选择动作的概率分布；\n可分为：\n确定性策略：a=π(s)\\pi(s)π(s) 随机策略：π(a∣s)=P(A=a∣S=s)\\pi(a|s) = P(A=a|S=s)π(a∣s)=P(A=a∣S=s) rtr_trt​：在时刻 ttt 获得的即时奖励；\nγ∈[0,1)\\gamma \\in [0, 1)γ∈[0,1)：折扣因子（Discount Factor），用于平衡短期与长期收益；\n目标是通过优化策略 π\\piπ，使期望累积奖励 J(π)J(\\pi)J(π) 最大化。\n**注意：**这里的期望 Eπθ\\mathbb{E}_{\\pi_\\theta}Eπθ​​ 实际上已经隐含了：\n环境状态转移的随机性； 策略选择动作的随机性。 也就是说，这个期望是对所有可能轨迹的加权平均。\n1.5 价值函数 # 价值函数用于评价状态或动作的“好坏”，帮助智能体优化策略。\n1. 状态价值函数\n表示在给定策略 π\\piπ 下，从状态 sss 开始能获得的期望长期回报： Vπ(s)=Eπ[∑t=0∞γtrt | s0=s] V^{\\pi}(s) = \\mathbb{E}_{\\pi}\\left[ \\sum_{t=0}^{\\infty} \\gamma^t r_t \\ \\middle| \\ s_0 = s \\right] Vπ(s)=Eπ​[t=0∑∞​γtrt​ ​ s0​=s] 它衡量了某个状态的“好坏”。\n2. 动作价值函数\n表示在给定策略 π\\piπ 下，从状态 sss 执行动作 aaa 后能获得的期望长期回报： Qπ(s,a)=Eπ[∑t=0∞γtrt | s0=s,a0=a] Q^{\\pi}(s, a) = \\mathbb{E}_{\\pi}\\left[ \\sum_{t=0}^{\\infty} \\gamma^t r_t \\ \\middle| \\ s_0 = s, a_0 = a \\right] Qπ(s,a)=Eπ​[t=0∑∞​γtrt​ ​ s0​=s,a0​=a] 它衡量了在某个状态下采取某个动作的“好坏”。\n3. 动作价值函数与状态价值函数之间的关系\n状态价值函数其实可以看作动作价值函数的期望（加权平均）： Vπ(s)=∑aπ(a∣s) Qπ(s,a) V^{\\pi}(s) = \\sum_{a} \\pi(a|s) \\, Q^{\\pi}(s, a) Vπ(s)=a∑​π(a∣s)Qπ(s,a) 也就是说：\n状态价值函数是对该状态下所有可能动作的加权平均（权重是策略在该状态下选择该动作的概率）。\n4. 优势函数\n用于衡量一个动作相对于该状态下平均水平的优劣程度： Aπ(s,a)=Qπ(s,a)−Vπ(s) A^{\\pi}(s, a) = Q^{\\pi}(s, a) - V^{\\pi}(s) Aπ(s,a)=Qπ(s,a)−Vπ(s) 优势函数反映了动作 aaa 相对于该状态下平均行为的增益程度，是许多强化学习算法（如PPO）的核心计算部分。\n若 Aπ(s,a)\u0026gt;0A^{\\pi}(s,a) \u0026gt; 0Aπ(s,a)\u0026gt;0，则动作 aaa 比策略平均表现更好； 若 Aπ(s,a)\u0026lt;0A^{\\pi}(s,a) \u0026lt; 0Aπ(s,a)\u0026lt;0，则动作 aaa 较差。\n1.6 探索与利用 # 在强化学习中，智能体（Agent）需要在未知环境中通过不断尝试来获得最大化长期奖励的策略。 然而，在学习过程中，智能体会面临一个关键的决策困境：\n是继续利用已有的经验选择当前看来最优的动作，还是去尝试新的动作以获取更多潜在的知识？\n这就是强化学习中著名的 “探索–利用权衡” 问题。\n1. 探索（Exploration）\n是指智能体有意识地尝试新的动作或策略，以发现可能带来更高奖励的行为。\n即：\n在不确定的情况下进行尝试； 通过探索新的状态–动作组合，获取更多关于环境的知识。 2. 利用（Exploitation）\n是指智能体根据当前已有的知识，选择当前看来能够带来最高奖励的动作。\n即：\n智能体使用已经学到的策略； 根据现有的价值函数或策略评估，选择期望回报最高的行为。 3. 常见平衡方法\nε-贪心策略（ε-Greedy Policy）\n以概率 1−ε1 - \\varepsilon1−ε 选择当前最优动作；\n以概率 ε\\varepsilonε 随机选择一个动作进行探索。\n数学表达式： a={arg⁡max⁡aQ(s,a),以概率 1−ε随机动作,以概率 ε a = \\begin{cases} \\arg\\max_a Q(s,a), \u0026amp; \\text{以概率 } 1-\\varepsilon \\\\ \\text{随机动作}, \u0026amp; \\text{以概率 } \\varepsilon \\end{cases} a={argmaxa​Q(s,a),随机动作,​以概率 1−ε以概率 ε​ 通常会让 ε\\varepsilonε 随时间逐渐减小，以在训练后期更多地利用已学知识。\nSoftmax 策略（Boltzmann Exploration）\n按照动作的 Q 值通过 Softmax 分布采样动作，使得较优动作被选择的概率更高： P(a∣s)=eQ(s,a)/τ∑a′eQ(s,a′)/τ P(a|s) = \\frac{e^{Q(s,a)/\\tau}}{\\sum_{a\u0026#x27;} e^{Q(s,a\u0026#x27;)/\\tau}} P(a∣s)=∑a′​eQ(s,a′)/τeQ(s,a)/τ​ 参数 τ\\tauτ（温度参数）控制探索程度：\nτ\\tauτ 大 → 行为更随机（高探索）； τ\\tauτ 小 → 更倾向选择最优动作（高利用）。 1.7 Actor–Critic 架构 # 在强化学习中，我们通常需要同时解决两个问题：\n如何选择动作？（即策略问题） 如何评估一个状态或动作的好坏？（即价值评估问题） Actor–Critic 架构将这两个问题分开处理：\nActor（行动者）：负责决策 —— 输出动作； Critic（评估者）：负责评估 —— 判断当前策略的好坏，并给出学习信号。 二者相互配合，共同完成策略优化。\n1. Actor\nActor 是策略网络（Policy Network），它的任务是根据当前状态 sts_tst​，输出一个动作的概率分布 π(at∣st;θ)\\pi(a_t|s_t;\\theta)π(at​∣st​;θ)。\n它决定“在当前状态下要采取什么行动”； 通过梯度上升（Policy Gradient）优化策略参数 θ\\thetaθ，以最大化期望奖励。 策略目标函数为： J(θ)=Eπθ[∑t=0∞γtrt] J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\\left[ \\sum_{t=0}^{\\infty} \\gamma^t r_t \\right] J(θ)=Eπθ​​[t=0∑∞​γtrt​] Actor 通过梯度更新策略参数（推导过程见 此处） ∇θJ(θ)=Eπθ[∇θlog⁡πθ(at∣st) Aπ(st,at)] \\nabla_\\theta J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\\left[\\nabla_\\theta \\log \\pi_\\theta(a_t|s_t) \\, A^{\\pi}(s_t, a_t)\\right] ∇θ​J(θ)=Eπθ​​[∇θ​logπθ​(at​∣st​)Aπ(st​,at​)] 其中：\nAπ(st,at)A^{\\pi}(s_t, a_t)Aπ(st​,at​) 是优势函数，用于衡量当前动作相对平均水平的好坏； 该项由 Critic 提供。 2. Critic\nCritic 是价值函数网络（Value Network），用于评估当前策略的表现，给 Actor 提供“学习方向”。\nCritic 的目标是预测状态或状态–动作的价值函数：\n状态价值函数：Vπ(s)V^{\\pi}(s)Vπ(s) 或动作价值函数：Qπ(s,a)Q^{\\pi}(s, a)Qπ(s,a) Critic 的优化目标通常是最小化价值函数的均方误差（MSE）： L(ϕ)=Eπθ[(Rt−Vϕ(st))2] L(\\phi) = \\mathbb{E}_{\\pi_\\theta}\\left[\\left(R_t - V_\\phi(s_t)\\right)^2\\right] L(ϕ)=Eπθ​​[(Rt​−Vϕ​(st​))2] 其中：\nRt=∑k=t∞γk−trkR_t = \\sum_{k=t}^{\\infty} \\gamma^{k-t} r_kRt​=∑k=t∞​γk−trk​ 为实际回报； Vϕ(st)V_\\phi(s_t)Vϕ​(st​) 为 Critic 网络估计的状态价值； 参数 ϕ\\phiϕ为 Critic 的网络权重。 3. Actor 与 Critic 的交互过程\n1️⃣ **状态输入：**智能体从环境中接收当前状态 sts_tst​。\n2️⃣ Actor 决策动作\n策略网络（Actor）根据当前状态输出一个动作分布：πθ(at∣st)\\pi_\\theta(a_t|s_t)πθ​(at​∣st​) 从该分布中采样得到一个具体动作 ata_tat​。 3️⃣ 环境反馈\n环境执行动作 ata_tat​ 后，返回： 奖励 rtr_trt​ 新状态 st+1s_{t+1}st+1​ 4️⃣ Critic 价值评估\n价值网络（Critic）根据当前状态估计价值：Vϕ(st)V_\\phi(s_t)Vϕ​(st​) 同时可计算下一状态的价值预测：Vϕ(st+1)V_\\phi(s_{t+1})Vϕ​(st+1​) 得到当前动作的“经验回报”（即目标Q值）：Qtarget(st,at)=rt+γVϕ(st+1)Q^{\\text{target}}(s_t,a_t) = r_t + \\gamma V_\\phi(s_{t+1})Qtarget(st​,at​)=rt​+γVϕ​(st+1​) 引导式估计（bootstrapping） 5️⃣ 计算优势函数（Advantage）\n通过 Critic 的输出，计算动作的相对优劣：A(st,at)=Qtarget(st,at)−Vϕ(st)A(s_t,a_t) = Q^{\\text{target}}(s_t,a_t) - V_\\phi(s_t)A(st​,at​)=Qtarget(st​,at​)−Vϕ​(st​)\n这个值表示“当前动作比平均水平好多少”。\n6️⃣ Actor 参数更新（策略优化）\n使用优势函数指导 Actor 更新策略参数：∇θJ(θ)=E[∇θlog⁡πθ(at∣st) A(st,at)]\\nabla_\\theta J(\\theta) = \\mathbb{E}\\left[\\nabla_\\theta \\log \\pi_\\theta(a_t|s_t) \\, A(s_t,a_t)\\right]∇θ​J(θ)=E[∇θ​logπθ​(at​∣st​)A(st​,at​)]\n当 A(st,at)\u0026gt;0A(s_t,a_t) \u0026gt; 0A(st​,at​)\u0026gt;0 时，该动作的概率将增加；反之减少。\n7️⃣ Critic 参数更新（价值拟合）\nCritic 通过最小化 MSE 损失函数来更新自身参数：\nL(ϕ)=(Qtarget(st,at)−Vϕ(st))2L(\\phi) = \\left(Q^{\\text{target}}(s_t,a_t) - V_\\phi(s_t)\\right)^2L(ϕ)=(Qtarget(st​,at​)−Vϕ​(st​))2\n即让预测值 Vϕ(st)V_\\phi(s_t)Vϕ​(st​) 更接近目标值 QtargetQ^{\\text{target}}Qtarget。\n8️⃣ 循环交互\n进入下一时间步 t+1t+1t+1，重复整个过程。 1.8 总结 # ​\t强化学习是一种基于奖励反馈机制的自学习方法，其核心目标是通过交互学习最优策略。在现代人工智能系统（如 RLHF 框架）中，强化学习方法尤其在优化大语言模型（LLM）的行为策略方面发挥着关键作用。掌握强化学习的基础知识 —— 尤其是价值函数、策略梯度与Actor–Critic 架构 —— 是理解 RLHF 的前提条件。\n2. GAE和重要性采样 # 2.1 GAE # 1. TD 误差与基本优势估计\n定义时间差分（TD）误差： δt=rt+γV(st+1)−V(st) \\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t) δt​=rt​+γV(st+1​)−V(st​) 那么最简单的优势近似是： A^t=δt \\hat{A}_t = \\delta_t A^t​=δt​ 这就是 1-step Advantage Estimation。虽然方差小，但偏差较大。\n2. n-step Advantage 与 GAE\n​\tGAE（广义优势估计，Generalized Advantage Estimation）由 Schulman 等人在 TRPO/PPO 中提出。它通过加权多步回报，在方差与偏差之间取得平衡。\nn步 TD 误差定义为： At(n)=∑l=0n−1(γ)lδt+l A_t^{(n)} = \\sum_{l=0}^{n-1} (\\gamma)^l \\delta_{t+l} At(n)​=l=0∑n−1​(γ)lδt+l​ 比如：\nn=1：只看一步（偏差大，方差小）； n=T：看整个序列（Monte Carlo，无偏但方差大）。 2. GAE 的核心思想\nGAE 通过一个衰减因子 λ∈[0,1]\\lambda \\in [0,1]λ∈[0,1]，将所有 n-step 优势加权平均： A^tGAE(γ,λ)=∑l=0∞(γλ)lδt+l \\hat{A}_t^{GAE(\\gamma,\\lambda)} = \\sum_{l=0}^{\\infty} (\\gamma \\lambda)^l \\delta_{t+l} A^tGAE(γ,λ)​=l=0∑∞​(γλ)lδt+l​ 或者用递推式写作： A^t=δt+γλA^t+1 \\hat{A}_t = \\delta_t + \\gamma \\lambda \\hat{A}_{t+1} A^t​=δt​+γλA^t+1​ 其中：\nγ\\gammaγ：折扣因子； λ\\lambdaλ：平滑参数（控制偏差与方差权衡）。作用如下 λ 值 行为 说明 λ = 0 等价于 1-step TD 方差小、偏差大 λ = 1 等价于 Monte Carlo Return 方差大、无偏 0 \u0026lt; λ \u0026lt; 1 折中方案（常用 λ=0.95） 最佳性能实践值 2.2 重要性采样 # 在策略梯度中，我们希望最大化期望： J(θ)=Eπθ[Rt] J(\\theta) = \\mathbb{E}_{\\pi_\\theta}[R_t] J(θ)=Eπθ​​[Rt​] 但实际数据通常来自 旧策略 πθold\\pi_{\\theta_{old}}πθold​​，而不是当前策略 πθ\\pi_\\thetaπθ​。为了重用旧数据，我们引入 重要性采样（Importance Sampling）。\n1. 重要性采样的基本原理\n设有两个分布 p(x)p(x)p(x) 和 q(x)q(x)q(x)，如果我们只能从 q(x)q(x)q(x) 采样，但要计算 Ep[f(x)]E_p[f(x)]Ep​[f(x)]， 则可以通过： Ep[f(x)]=Eq[p(x)q(x)f(x)] E_p[f(x)] = E_q\\left[ \\frac{p(x)}{q(x)} f(x) \\right] Ep​[f(x)]=Eq​[q(x)p(x)​f(x)] 这时 p(x)q(x)\\frac{p(x)}{q(x)}q(x)p(x)​称为 重要性权重（importance weight）。\n2. 应用于策略梯度\n在策略梯度中： ∇θJ(θ)=Eπθ[∇θlog⁡πθ(at∣st)At] \\nabla_\\theta J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\\left[ \\nabla_\\theta \\log \\pi_\\theta(a_t|s_t) A_t \\right] ∇θ​J(θ)=Eπθ​​[∇θ​logπθ​(at​∣st​)At​] 但数据来自旧策略 πold\\pi_{old}πold​，于是用重要性采样修正分布： ∇θJ(θ)=Eπold[πθ(at∣st)πold(at∣st)∇θlog⁡πθ(at∣st)At] \\nabla_\\theta J(\\theta) = \\mathbb{E}_{\\pi_{old}}\\left[ \\frac{\\pi_\\theta(a_t|s_t)}{\\pi_{old}(a_t|s_t)} \\nabla_\\theta \\log \\pi_\\theta(a_t|s_t) A_t \\right] ∇θ​J(θ)=Eπold​​[πold​(at​∣st​)πθ​(at​∣st​)​∇θ​logπθ​(at​∣st​)At​] 定义：\nrt(θ)=πθ(at∣st)πold(at∣st)r_t(\\theta) = \\frac{\\pi_\\theta(a_t|s_t)}{\\pi_{old}(a_t|s_t)}rt​(θ)=πold​(at​∣st​)πθ​(at​∣st​)​ 称为 重要性比（importance ratio）。\n3. PPO算法 # 具体算法介绍请看 零基础学习强化学习算法：PPO\n3.1 PPO算法的背景与思想 # 在传统的 策略梯度方法（Policy Gradient） 中，我们的目标是最大化策略的期望回报： J(θ)=Eπθ[Rt] J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\\left[ R_t \\right] J(θ)=Eπθ​​[Rt​] 更新方式为： ∇θJ(θ)=Eπθ[∇θlog⁡πθ(at∣st)At] \\nabla_\\theta J(\\theta) = \\mathbb{E}_{\\pi_\\theta}\\left[ \\nabla_\\theta \\log \\pi_\\theta(a_t|s_t) A_t \\right] ∇θ​J(θ)=Eπθ​​[∇θ​logπθ​(at​∣st​)At​] 但是这种直接更新有两个问题：\n方差大 → 更新不稳定； 步长难控制 → 策略可能剧烈变化，导致性能退化。 于是，PPO 出现了：\n它通过“约束策略更新幅度”来稳定训练， 同时结合 GAE 降低方差、结合 重要性采样 重用旧数据。\n3.2 PPO核心目标函数 # 在 PPO 中，我们定义： rt(θ)=πθ(at∣st)πθold(at∣st) r_t(\\theta) = \\frac{\\pi_\\theta(a_t|s_t)}{\\pi_{\\theta_{old}}(a_t|s_t)} rt​(θ)=πθold​​(at​∣st​)πθ​(at​∣st​)​ 这是 重要性比（importance ratio）， 用于衡量新旧策略在同一状态下采取相同行动的相对概率。\n原始目标（带重要性采样的策略梯度） LPG(θ)=Et[rt(θ)A^t] L^{PG}(\\theta) = \\mathbb{E}_{t}\\left[ r_t(\\theta) \\hat{A}_t \\right] LPG(θ)=Et​[rt​(θ)A^t​] 其中 A^t\\hat{A}_tA^t​ 为优势函数估计（通常使用 GAE 计算）。\nPPO 的关键创新：Clipped Objective\n为了防止策略变化过大，PPO 引入“裁剪”操作： LCLIP(θ)=Et[min⁡(rt(θ)A^t, clip(rt(θ),1−ϵ,1+ϵ)A^t)] L^{CLIP}(\\theta) = \\mathbb{E}_{t}\\left[ \\min\\Big( r_t(\\theta)\\hat{A}_t, \\; \\text{clip}(r_t(\\theta), 1-\\epsilon, 1+\\epsilon)\\hat{A}_t \\Big) \\right] LCLIP(θ)=Et​[min(rt​(θ)A^t​,clip(rt​(θ),1−ϵ,1+ϵ)A^t​)] 当 rt(θ)r_t(\\theta)rt​(θ) 远离 1 时（即策略变化太大），更新会被裁剪； ϵ\\epsilonϵ 是一个超参数（通常取 0.1～0.3）。 这样做能有效防止策略“走太远”，稳定训练。\n3.3 GAE与PPO # PPO 仍是一个 Actor-Critic 结构：\nActor（策略网络） 更新基于 LCLIPL^{CLIP}LCLIP； Critic（价值网络） 更新基于状态值误差。 其中 Advantage 使用 GAE 计算： A^t=∑l=0∞(γλ)lδt+l,δt=rt+γV(st+1)−V(st) \\hat{A}_t = \\sum_{l=0}^{\\infty} (\\gamma \\lambda)^l \\delta_{t+l}, \\quad \\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t) A^t​=l=0∑∞​(γλ)lδt+l​,δt​=rt​+γV(st+1​)−V(st​) GAE 优势：\n平滑优势估计； 平衡偏差与方差； 提高 PPO 策略更新的稳定性。 3.4 价值函数与熵正则项 # PPO 的总损失函数通常包含三部分： LPPO(θ)=Et[LCLIP(θ)−c1(Vθ(st)−Rt)2+c2S[πθ](st)] L^{PPO}(\\theta) = \\mathbb{E}_t[ L^{CLIP}(\\theta) - c_1 (V_\\theta(s_t) - R_t)^2 + c_2 S[\\pi_\\theta](s_t) ] LPPO(θ)=Et​[LCLIP(θ)−c1​(Vθ​(st​)−Rt​)2+c2​S[πθ​](st​)] 在实现中，定义损失为： loss=−LPPO(θ) loss=−L^{PPO}(\\theta) loss=−LPPO(θ) 其中：\n第一项：策略目标（Actor 部分）； 第二项：价值函数损失（Critic 部分）； 第三项：熵正则项，用于鼓励探索（防止策略过早收敛）； c1,c2c_1, c_2c1​,c2​ 为权重超参数。 3.5 PPO 的训练流程（结合 GAE 与 IS） # Step 1：采样轨迹\n从当前旧策略 πθold\\pi_{\\theta_{old}}πθold​​ 运行若干步，获得：{(st,at,rt,st+1)}\\{(s_t, a_t, r_t, s_{t+1})\\}{(st​,at​,rt​,st+1​)}\nStep 2：计算优势函数（使用 GAE）\n计算 TD 误差：δt=rt+γV(st+1)−V(st)\\delta_t = r_t + \\gamma V(s_{t+1}) - V(s_t)δt​=rt​+γV(st+1​)−V(st​) 递推得到 GAE：A^t=δt+γλA^t+1\\hat{A}_t = \\delta_t + \\gamma \\lambda \\hat{A}_{t+1}A^t​=δt​+γλA^t+1​ 同时计算回报目标：Rt=A^t+V(st)R_t = \\hat{A}_t + V(s_t)Rt​=A^t​+V(st​) Step 3：计算重要性比 rt=πθ(at∣st)πθold(at∣st) r_t = \\frac{\\pi_\\theta(a_t|s_t)}{\\pi_{\\theta_{old}}(a_t|s_t)} rt​=πθold​​(at​∣st​)πθ​(at​∣st​)​ Step 4：优化目标函数\n最小化负的 PPO 损失（或最大化其相反数）： LCLIP=E[min⁡(rtA^t,clip(rt,1−ϵ,1+ϵ)A^t)] L^{CLIP} = \\mathbb{E}\\left[ \\min(r_t \\hat{A}_t, \\text{clip}(r_t,1-\\epsilon,1+\\epsilon)\\hat{A}_t) \\right] LCLIP=E[min(rt​A^t​,clip(rt​,1−ϵ,1+ϵ)A^t​)] Step 5：更新网络\nActor：梯度上升最大化 LCLIPL^{CLIP}LCLIP； Critic：最小化 (V(st)−Rt)2(V(s_t) - R_t)^2(V(st​)−Rt​)2； 更新旧策略：θold←θ\\theta_{old} \\leftarrow \\thetaθold​←θ 循环执行以上步骤，直到收敛。\n3.6 PPO 的两个常见变体 # 变体 特点 说明 PPO-Clip 使用裁剪函数（最常见） 稳定且计算简单 PPO-Penalty 使用KL散度惩罚项 控制新旧策略距离（但需调节系数 β） PPO-Clip 的目标： LCLIP=E[min⁡(rtAt,clip(rt,1−ϵ,1+ϵ)At)] L^{CLIP} = \\mathbb{E}[\\min(r_t A_t, \\text{clip}(r_t,1-\\epsilon,1+\\epsilon)A_t)] LCLIP=E[min(rt​At​,clip(rt​,1−ϵ,1+ϵ)At​)] PPO-Penalty 的目标： LKL=E[rtAt−βDKL(πold∥πθ)] L^{KL} = \\mathbb{E}[r_t A_t - \\beta D_{KL}(\\pi_{old} \\| \\pi_\\theta)] LKL=E[rt​At​−βDKL​(πold​∥πθ​)] 3.7 常用超参数参考 # 参数 含义 常用值 γ 折扣因子 0.99 λ GAE 衰减因子 0.95 ε PPO 裁剪范围 0.1 – 0.3 c₁ 价值函数损失权重 0.5 c₂ 熵项权重 0.01 学习率 优化器步长 3e-4 (Adam) 3.8 总结 # ​\tPPO = Policy Gradient + Importance Sampling + Clipping + GAE\n​\t它在保持理论简洁性的同时，实现了高稳定性与强泛化能力，是当前最常用、性能最稳定的策略梯度算法之一。\n4. RLHF框架 # 4.1 RLHF概述 # ​\tRLHF（Reinforcement Learning from Human Feedback） 是一种结合了监督学习与强化学习的训练方法，旨在让大语言模型（LLM）更好地符合人类偏好。其核心思想是：\n通过人类反馈训练一个奖励模型（Reward Model），再利用强化学习算法优化语言模型，使其生成的回答更接近人类认为“好的”输出。\n4.2 RLHF三阶段流程 # 1. 监督微调（SFT）\n​\t在第一阶段，使用人工编写的高质量问答对（prompt–response）对基准语言模型进行监督微调。\n​\t这一阶段的目标是让模型具备初步的对话能力和任务理解能力，为后续的偏好学习奠定基础。\n输入： 人工标注的高质量问答数据集 输出： 监督微调后的初始模型（SFT Model）\n2. 奖励模型训练（Reward Model Training）\n​\t第二阶段引入偏好数据集（Preference Dataset）。\n​\t该数据集由基准模型生成多个回答，然后由人类标注者对回答进行两两比较，指出更优的回答。格式为 (prompt,chosen_response, rejected_response)\n通过这种“偏好对比（Preference Comparison）”数据，可以训练一个**奖励模型（Reward Model）**，使其学会评估回答的优劣。 ​\t这不是一个简单的回归（预测一个绝对分数），而通常是一个排序学习问题。最常用的损失函数是Bradley-Terry模型的对比损失： loss = -log(sigmoid(R(chosen) - R(rejected)))\n​\t这个损失函数的核心思想是：最大化“被选中的回答”和“被拒绝的回答”之间的分数差。模型最终学会输出一个标量分数 R(prompt, response)，这个分数代表了回答的质量。\n输入： 偏好数据集（包含 prompt 及人类偏好排序的多个 response） 输出： 奖励模型 Reward Model，能够为任意 (prompt, response) 对打分\n​\t奖励模型的输出通常是一个标量分数，表示该回答的“好坏程度”。该模型在强化学习阶段中充当“环境”的角色，用于提供奖励信号。\n3. 强化学习阶段（Reinforcement Learning Fine-Tuning）\n​\t在第三阶段，使用强化学习算法（如 PPO，Proximal Policy Optimization）优化语言模型，使其生成更高分的回答。\n​\t这一阶段通常使用仅包含提示（prompt）的提示数据集（Prompt Dataset）。\n​\t训练过程如下：\n模型根据 prompt 生成 response。 奖励模型对该 response 打分（即计算 reward）。 强化学习算法根据 reward 信号更新语言模型的策略参数，使得模型倾向于生成更高奖励的回答。 环境： 由提示数据集和奖励模型共同构成。\n智能体： 我们需要训练的模型（通常由SFT模型初始化），它的策略就是根据提示生成回答的概率分布。\n动作： 生成下一个词元（token）。\n状态： 当前的提示和已经生成的部分回答。\n奖励： 当一个完整的回答生成后，奖励模型会给出一个奖励分数 R(prompt, response)。\n目标： 最大化奖励模型给出的期望奖励，使模型输出更符合人类偏好。\n4.3 强化学习优化算法 # 在强化学习阶段，常用的算法包括：\nPolicy Gradient（策略梯度）：通过梯度上升最大化期望奖励； PPO（Proximal Policy Optimization）：一种改进的策略优化算法，能够在稳定性与性能之间取得平衡，因而成为当前主流的RLHF优化方法（如ChatGPT使用的）。 4.4 总结 # ​\tRLHF框架的核心思想是将人类偏好引入模型优化过程。 通过“人类反馈 → 奖励模型 → 强化学习”三阶段机制，语言模型不仅学习语言规律，更学习“人类认为好的”回答方式，从而在开放式任务中展现出更自然、更符合人类预期的行为。\n5. InstructGPT # 5.1 摘要与背景 # ​\t传统大规模自回归语言模型（GPT 系列）用下一个 token 的似然作为训练目标（无监督 / 自监督），因此，模型输出的“好坏”高度依赖训练语料的分布与质量，在“服从用户意图、真实与无害”这类目标上存在明显错配（misalignment）。为了解决这一点，作者提出用人类偏好（human preference）作为监督信号来微调模型，使模型“更会听指令、更真实、更少有毒内容”。\n5.2 核心思想 # ​\tInstructGPT 的核心是 RLHF（Reinforcement Learning from Human Feedback）三步流水线：\nSFT（Supervised Fine-Tuning）：人工编写/演示式示例（labeler demonstrations）对 GPT-3 进行监督微调，使模型初步“学会”如何按指令产出。 RM（Reward Model）训练：对不同模型产出的若干候选回复，由人工对这些回复进行排序（ranking）；训练一个模型 rθ(x,y) 来预测人工偏好（即哪个回复会被人工选为更好）。 RL（PPO）微调：以 RM 的输出作为 reward，使用 PPO 对 SFT 模型进一步微调；为了防止策略漂移过大，在每个 token 上加入与 SFT 模型的 KL 惩罚（或把预训练梯度混入，见 PPO-ptx）。 创新点：\n将\u0026quot;人类偏好\u0026quot;量化为可优化的目标函数 将复杂的价值观判断分解为可学习的奖励信号 在保持语言能力的同时优化对话质量 5.3 数据集：来源、格式与规模 # ​\t数据来源分为两类：labeler-written（由标注员构造的指令/示例） 和 customer/API prompts（用户在 Playground 提交的真实提示）。论文把用于三类训练的集合列出（以“prompt 数目”计）：\nSFT 数据集（问答指令数据集） train — labeler：11,295 train — customer：1,430 valid — labeler：1,550 valid — customer：103 RM（Reward Model）数据集（偏好数据集，用于训练排序模型） train — labeler：6,623 train — customer：26,584 valid — labeler：3,488 valid — customer：14,399 PPO（RL）数据集（提示数据集，用于 RLHF，通常只用 customer prompts） train — customer：31,144 valid — customer：16,185 其他数据相关要点：\n对每个 prompt，标注员会对模型生成的 K 个候选回复（K≈4–9） 做排名（ranking），这样会产生 up to K*(K−1)/2 对比较对，用于 RM 的训练。为了避免过拟合，论文对比较对做了特殊的批处理策略（把一个 prompt 的所有 K 个回复作为一个 batch 单元来训练 RM）。 标注团队约 40 名承包标注员（contractors），通过筛选测试选出，以保证对“有害/有偏见内容”的敏感性和偏好判定的质量。 5.4 模型与训练目标 # SFT（监督微调）\n用标注员示范（prompt → 人工写好的理想回复）做 标准的交叉熵监督训练，得到 SFT 基线模型（ πSFT\\pi_{SFT}πSFT​ ）。 Reward Model（RM）的损失（pairwise ranking loss）\nRM 的目标是给出一个标量评分 rθ(x,y)r_\\theta(x,y)rθ​(x,y)，使得被人工偏好（更好的回复）在评分上高于被舍弃的回复。论文采用成对排序的 logistic（pairwise）损失，形式为（论文公式）： L(θ)=−E(x,yw,yℓ)∼D[log⁡σ(rθ(x,yw)−rθ(x,yℓ))] \\mathcal{L}(\\theta) = -\\mathbb{E}_{(x,y_w,y_\\ell)\\sim D}\\left[\\log \\sigma\\big(r_\\theta(x,y_w)-r_\\theta(x,y_\\ell)\\big)\\right] L(θ)=−E(x,yw​,yℓ​)∼D​[logσ(rθ​(x,yw​)−rθ​(x,yℓ​))]其中 ywy_wyw​ 是在一对中被标注为“更好”的回复，σ\\sigmaσ 为 sigmoid。直观上，这个损失推动 RM 使得被偏好的回复得分更高。论文在实现上把对同一 prompt 的 K 个回复一起作为一个 batch 元素来避免过拟合并加速。\nRL（PPO）阶段的目标（带 KL 惩罚的期望 reward）\n用 RM 的标量 rθ(x,y)r_\\theta(x,y)rθ​(x,y) 作为 reward，使用 PPO 对策略 πRL\\pi_{RL}πRL​（基于 SFT 初始化）做 max-reward 的强化学习。为避免模型为了追求 reward 而产生极端/不可控输出，加入 每个 token 的 KL 惩罚项（相对于 πRL\\pi_{RL}πRL​），并且论文还提出将预训练（pretrain）目标的梯度混入 RL 更新（称为 PPO-ptx），总体目标可以写成（论文给出的合并形式）： J(ϕ)=E(x,y)∼DπRLϕ[rθ(x,y)−βlog⁡πRLϕ(y∣x)πSFT(y∣x)] + γ Ex∼Dpretrain[log⁡πRLϕ(x)] J(\\phi)=\\mathbb{E}_{(x,y)\\sim D_{\\pi_{\\text{RL}}^\\phi}}\\Big[ r_\\theta(x,y) - \\beta\\log\\frac{\\pi_{\\text{RL}}^\\phi(y\\mid x)}{\\pi_{\\text{SFT}}(y\\mid x)}\\Big] \\;+\\; \\gamma\\,\\mathbb{E}_{x\\sim D_{\\text{pretrain}}}\\big[\\log \\pi_{\\text{RL}}^\\phi(x)\\big] J(ϕ)=E(x,y)∼DπRLϕ​​​[rθ​(x,y)−βlogπSFT​(y∣x)πRLϕ​(y∣x)​]+γEx∼Dpretrain​​[logπRLϕ​(x)] 解释：第一项是期望 RM-reward 减去 β 倍的 KL（以 log 比例表示），第二项（系数 γ）是把预训练 log-likelihood 加回去以减轻 RL 导致的“性能回退”（alignment tax）。 5.5 训练流程 # 总体步骤：SFT → 收集候选回复与人工排序 → 训练 RM → 用 RM 做 reward，对 SFT 用 PPO 微调（可选混入 pretrain 梯度）。\n详细流程与实现细节（论文中给出的关键点）：\n收集 prompt 与演示：从 InstructGPT Playground 的用户 prompt（customer）与标注员编写的 prompt（labeler）混合构成数据集；标注员为每个 prompt 编写示范回答（SFT 数据）。 SFT 训练：用标注员示范训练 SFT 模型（cross-entropy）。 生成候选并做 ranking：用多种策略（不同 checkpoint /温度等）生成 K 个回复，标注员对 K 个回复做完整排序（K≈4–9），得到比较数据。论文采用将同一 prompt 的所有 K 个回复作为一个 batch 元素来训练 RM，能避免过拟合。 训练 RM：训练单一 6B reward model（论文最终选择用 6B RM，因为 175B RM 虽然在验证集上 loss 低但训练不稳定、计算成本更高），RM 用上面的 pairwise loss 训练（一般训练 1 个 epoch，lr ≈ 9e-6，batch size 64，训练对 epoch 很敏感，过多 epoch 会过拟合）。 RM 的初始化：用一个 6B GPT-3（在若干公开 NLP 数据集上 finetune 的 checkpoint）来初始化。 PPO 微调（RLHF）：把 RM 输出作为即时 reward 的 bandit 环境（每条交互是一个 prompt → single reply，episode 立即结束）。用 PPO 对策略进行优化，同时在 reward 中加入 每 token KL 惩罚（系数 β），并可选地把预训练 loss 的梯度混入（系数 γ，得到 PPO-ptx）以减轻在某些公共 NLP 任务上的性能回退。论文中 value-function 也用 6B 初始化并被用于 advantage 估计。 超参数/工程要点（论文中提及）：\nRM 训练：单 epoch，lr≈9e-6，batch=64；太多 epoch 会快速过拟合。 对所有 PPO 模型都使用同一个 6B RM 与 6B value function（便于比较不同 policy model 尺寸的效果）。具体 policy 的 value learning rate：1.3B/6B 用 9e-6，175B 用 5e-6（论文给出实验细节）。 5.6 实验结果 # 人类偏好（primary claim）：在作者的 API prompt 分布上，标注员显著更喜欢 InstructGPT（经过 RLHF 的模型）输出。甚至 1.3B 的 InstructGPT（PPO-ptx）在他们的测试集上比 175B 原始 GPT-3 更受偏好（也就是说，微调与对齐带来的质量提升可以超过单纯“更大模型参数量”的提升）。此外，175B 的 InstructGPT 比 175B GPT-3 优胜约 85% ± 3%，相比 few-shot 175B GPT-3 优胜 71% ± 4%。\n真实性（TruthfulQA）与杜绝 hallucination：在 TruthfulQA 基准上，InstructGPT 生成既真实又信息性的回答出现频率约为 GPT-3 的 2 倍；在“闭域任务”（输出不应包含输入中不存在的信息）中，InstructGPT 的编造（hallucination）率约 21%，而 GPT-3 为 41%。\n毒性（toxicity）：使用 RealToxicityPrompts 做自动与人工评估，InstructGPT 在“被提示要尊重他人”的条件下生成的有毒回复 减少约 25%。在偏见（bias）评估（例如 Winogender、CrowS-Pairs）上并未显著好转。\n性能退化（alignment tax）：直接用 RLHF 可能会在某些公共 NLP 基准（如 SQuAD、DROP、HellaSwag、WMT 翻译）上出现回退；论文提出通过将预训练梯度混入（PPO-ptx）可以在不牺牲偏好得分的前提下显著降低这种回退。\n5.7 局限性与讨论 # 偏好是“有界”的：RLHF 将模型对齐到特定标注员（约 40 人）的偏好集合，这 并不等同于普遍的人类价值；不同群体/不同标注员偏好可能不同，模型如何在更广泛的人群上泛化仍需研究。\nRM 与 reward hacking 风险：RM 只是对人工偏好的近似，RL 阶段可能出现对 RM 的过度优化（并非真实的“更好”），因此论文采用 KL 惩罚与其它工程手段来缓解，但长期来看这是一个研究难点。\n不彻底解决事实性与安全性问题：尽管在 TruthfulQA 和毒性上有改善，但 InstructGPT 仍会犯错、编造事实或在某些任务上表现不佳。论文把这当作未来工作方向。\n5.8 结论 # InstructGPT 的关键贡献不是提出全新的模型架构，而是展示了 用人类偏好来微调大型 LM（RLHF）是一个可行的、在实用上效果显著的对齐路径：能显著提升“按指令回答”和“更可信 / 更少有毒”这类行为，且在某些情况下，参数更少但经过 RLHF 的模型优于更大的未对齐模型。\n三步流程（SFT → RM → PPO）是工程上可复现的管线；关键实践包括：用合适的 batch \u0026amp; ranking 策略训练 RM、用 KL（或 pretrain-mix）限制 RL drift、用单一稳定的 6B RM 作为所有策略的 reward 模型以降低不稳定性与计算开销。\n参考：\nRLHF框架\n强化学习算法：Policy Gradient\n强化学习算法：PPO\nLLM 系列超详细解读 (四)：InstructGPT：训练语言模型以遵从人类指令\nInstructGPT 论文精读\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/instructgpt/","section":"文章","summary":"","title":"InstructGPT","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/instructgpt/","section":"标签","summary":"","title":"InstructGPT","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/rlhf/","section":"标签","summary":"","title":"RLHF","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/gpt/","section":"标签","summary":"","title":"GPT","type":"tags"},{"content":" 前言 # ​\tGenerative Pre-trained Transformer（GPT）系列是由OpenAI提出的非常强大的预训练语言模型，这一系列的模型可以在非常复杂的NLP任务中取得非常惊艳的效果，例如文章生成，代码生成，机器翻译，Q\u0026amp;A等，而完成这些任务并不需要有监督学习进行模型微调。而对于一个新的任务，GPT仅仅需要非常少的数据便可以理解这个任务的需求并达到接近或者超过state-of-the-art的方法。\n​\t当然，如此强大的功能并不是一个简单的模型能搞定的，GPT模型的训练需要超大的训练语料，超多的模型参数以及超强的计算资源。GPT系列的模型结构秉承了不断堆叠transformer的思想，通过不断的提升训练语料的规模和质量，提升网络的参数数量来完成GPT系列的迭代更新的。GPT也证明了，通过不断的提升模型容量和语料规模，模型的能力是可以不断提升的。\n1. GPT：无监督学习 # 1.1 核心思想 # ​\t在GPT-1之前（和ELMo同一年），传统的NLP模型往往使用大量的数据对有监督的模型进行任务相关的模型训练，但是这种有监督学习的任务存在两个缺点：\n​\t需要大量的标注数据，高质量的标注数据往往很难获得，因为在很多任务中，图像的标签并不是唯一的或者实例标签并不存在明确的边界；\n​\t根据一个任务训练的模型很难泛化到其它任务中，这个模型只能叫做“领域专家”而不是真正的理解了NLP。\n​\tGPT-1的核心思想是，通过在超大规模无标签文本数据上进行生成式预训练，模型可以学习到丰富的世界知识和强大的语言规律。然后，对于特定的下游任务，只需要在预训练模型的基础上，通过一个简单的线性输出层 和少量的任务特定数据微调，就能将学到的通用知识迁移到新任务上，从而取得优异的效果。这种方法的优势在于避免了为每个任务从头设计模型架构。\n​\t预训练+微调范式：GPT-1开创性地将”预训练+微调“这一范式应用于NLP领域，证明其强大效力。\n​\t模型架构：它使用了Transformer的解码器 堆叠，并且只使用了掩码自注意力机制，没有使用编码器-解码器架构中的交叉注意力机制。这使得模型在生成每个词时只能关注到它左侧的上下文，是一个单向模型。\n​\t预训练目标：预训练阶段采用标准的语言模型目标，即根据前文预测下一个词，目标是最大化似然估计。\n​\t下游任务适配：针对不同的下游任务，会构造不同的输入序列格式。\n1.2 GPT的训练 # ​\tGPT-1的训练分为无监督的预训练和有监督的模型微调。\n1.2.1 无监督预训练 # ​\t采用标准的语言模型的目标函数，即似然函数，根据前k个词预测下一个词的概率。具体如下图所示：\n1.2.2 有监督微调 # ​\t使用完整的输入序列+标签。目标函数=有监督的目标函数+λ*无监督的目标函数。具体如下图所示：\n​\t有监督微调中最终损失函数 L3 = L2 + w * L1\n​\tL1： 预训练任务的损失，即语言模型损失（根据上下文预测下一个词）。\n​\tL2： 应该是下游监督任务的损失（如分类任务的交叉熵损失）。\n​\t最终损失： L_total = L_supervised + λ * L_LM。\n​\t在微调时，模型同时优化两个目标：\n主要目标：正确完成下游任务（如分类正确），对应损失 L_supervised。 辅助目标：继续保持强大的语言建模能力，对应损失 L_LM。 ​\t这里的 λ 是一个超参数，论文中通常设置为 0.5。即总损失是监督损失和语言模型损失的加权和。\n1.3 GPT-1针对不同下游任务的输入/输出构造 # ​\tGPT-1论文中主要针对四大类任务设计了输入变换。其核心思想是：将所有任务都重构为模型在预训练时见过的“序列预测”任务。\n输入序列由以下几种标记（Token）构成：\nStart： 序列的开始标记。 Delim： 分隔符，用于分隔不同的句子或部分。 Extract： 提取标记，用于某些需要提取答案的任务。 以下是四个任务的详细构造：\n1. 文本分类（Classification，如情感分析）\n输入构造： [Start] Text [Extract] 输出： 将Transformer最后一个时间步（即 [Extract] 标记对应位置）的输出，送入一个线性层+Softmax进行分类。 解释： 模型读取整个文本，然后在最后输出一个分类结果。 2. 文本蕴含（Entailment，即自然语言推理）\n输入构造： [Start] Premise [Delim] Hypothesis [Extract] 输出： 同样，将 [Extract] 标记对应的输出送入线性分类器，判断蕴含、矛盾或中立。 解释： 将前提和假设用分隔符连接，让模型理解两者关系后做出判断。 3. 相似度计算（Similarity）\n输入构造： 由于两个句子没有固定顺序，为了对称性，会构造两个输入序列。 序列1: [Start] Text1 [Delim] Text2 [Extract] 序列2: [Start] Text2 [Delim] Text1 [Extract] 输出： 将两个输入序列在 [Extract] 标记处的输出按元素相加，然后将结果送入线性分类器预测相似度。 解释： 通过两种顺序的输入，让模型更好地理解句子间的对称相似关系。 4. 多项选择问答（Multiple Choice QA）\n输入构造： 对于每个候选答案，都构造一个输入序列。\n对于每个 (Context, Question, Answer_i)，构造： [Start] Context [Delim] Question [Delim] Answer_i [Extract]\n输出： 对每个候选答案对应的序列，计算 [Extract] 标记处的输出，并通过一个线性层得到一个分数。最后对所有候选答案的分数进行Softmax归一化，选择概率最高的答案。\n解释： 将问题和每个候选答案组合成一个新的“文本”，让模型判断哪个组合最通顺、最合理。\n以上部分的解释如下图所示： 1.4 GPT数据集 # ​\tGPT-1使用了BooksCorpus数据集。这个数据集包含7000本没有发布的书籍。作者选这个数据集的原因有二：\n数据集拥有更长的上下文依赖关系，使得模型能学得更长期的依赖关系； 这些书籍因为没有发布，所以很难在下游数据集上见到，更能验证模型的泛化能力。 1.5 GPT模型参数 # 1.5.1预训练阶段 # 模型架构： 12层Transformer解码器堆叠。 注意力头数： 12个。 隐藏层维度： 768维。 FFN层中间维度： 3072维（4倍隐藏层维度）。 激活函数： GELU。 参数总量： 约1.17亿 (117M)。 词表大小： 40,000 Byte Pair Encoding (BPE) 子词单元。 位置编码： 学习式的位置编码。 预训练数据： BooksCorpus数据集（约7,000本未出版的书籍）。 优化器： Adam。 Batch Size： 64个随机、连续的文本序列，每个序列长度为512个token。 学习率为 2.5e-4 ，训练epoch为 100；模型参数数量为 1.17亿。 1.5.2 有监督微调阶段 # 模型主体架构和参数： 从预训练模型中初始化。 新增参数： 仅在顶部为每个任务添加一个线性分类层。 超参数： 大部分超参数与预训练阶段相同。训练的epoch为3 ，学习率为 6.25e-5。 损失函数权重 λ： 论文中在大部分任务上设置为 0.5。 1.6 实验结果 # ​\t在有监督学习的12个任务中，GPT-1在9个任务上的表现超过了state-of-the-art的模型。在没有见过数据的zero-shot任务中，GPT-1的模型要比基于LSTM的模型稳定，且随着训练次数的增加，GPT-1的性能也逐渐提升，表明GPT-1有非常强的泛化能力，能够用到和有监督任务无关的其它NLP任务中。GPT-1证明了transformer对学习词向量的强大能力，在GPT-1得到的词向量基础上进行下游任务的学习，能够让下游任务取得更好的泛化能力。对于下游任务的训练，GPT-1往往只需要简单的微调便能取得非常好的效果。\n​\tGPT-1在未经微调的任务上虽然也有一定效果，但是其泛化能力远远低于经过微调的有监督任务，说明了GPT-1只是一个简单的领域专家，而非通用的语言学家。\n1.7 GPT与BERT区别 # 特性 GPT-1 BERT 核心架构 Transformer 解码器 Transformer 编码器 注意力机制 单向 / 因果。只能关注左侧上下文，采用掩码自注意力。 双向。可以同时关注左右两侧的上下文，采用完全自注意力。 预训练目标 自回归语言模型。目标是根据前文预测下一个词。 去噪自编码。主要目标是 Masked Language Model，即随机遮盖词并预测它。此外还有 Next Sentence Prediction 任务。 模型能力倾向 生成任务。因其单向特性，天然适合文本生成、对话等。 理解任务。因其双向特性，在文本分类、阅读理解、实体识别等理解类任务上表现更强。 数据流 从左到右的序列，适合序列生成。 一次性看到整个句子，适合整体理解。 微调策略 在总损失中引入语言模型损失作为辅助 (L_total = L_task + λ * L_LM)。 通常只优化下游任务损失 (L_task)，不使用MLM损失作为辅助。 代表性参数 117M 参数 (GPT-1 base) BERT_Base: 110M 参数 (12层, 768隐层, 12头) BERT_Large: 340M 参数 (24层, 1024隐层, 16头) 哲学差异 “通过生成来理解”。认为一个能够完美预测下一个词的模型，必然已经深刻理解了语言。 “通过完形填空来理解”。认为通过恢复被破坏的文本，可以学习到词语和句子间的深层关系。 3. GPT-2：多任务学习 # 2.1 核心思想 # ​\tGPT-2 的提出是对 GPT-1 “预训练+微调”范式的一次突破。\n​\t在 GPT-1 中，模型先在大规模语料上进行语言模型预训练，再通过有监督的微调适配具体下游任务。这种方式虽然有效，但存在两个明显弊端：\n任务依赖性强：每个任务都需要收集带标签的数据集，并针对性地微调模型。 计算开销大：微调过程不仅耗时，还需要大量计算资源，这使得跨领域泛化能力受限。换句话说，GPT-1 并不是“真正的通用智能”，而是一种带有监督信号的领域自适应。 ​\tGPT-2 的核心论点是：当模型容量足够大且训练语料足够广泛、多样时，一个纯粹的自回归语言模型通过学习文本续写的条件概率，就能够在统计意义上覆盖并完成许多传统上需要监督学习解决的下游任务，而无需对每个任务单独微调。换言之，许多有监督任务（翻译、摘要、问答等）只是自然语料中出现的特定文本格式或“任务范式”的子集；模型在大规模语料中见过类似的任务描述与示例后，就能通过在推理时给出自然语言提示（prompt）来实现 zero-shot 的任务执行。为保证 zero-shot 的有效性，提示应尽量与预训练时看到的自然语言形式一致（而不是包含训练时未见过的特殊开始/结束符号或人工结构化标记），例如用“Translate to French: …”或“Summary: …”这样自然的任务描述来引导模型续写。\n​\tZero-Shot 学习（零样本学习）：在推理阶段，不再依赖标注数据和参数更新，只需提供一个描述任务的 Prompt。由于模型在训练语料中已经见过各种任务的自然描述，它能够直接“类比”完成这些任务。例如：\n在训练中见过“文章：… 总结：…” → 学会摘要。 见过“English: … French: …” → 学会翻译。 见过“问题：… 答案：…” → 学会问答。 因此，在推理时，只要用类似的提示语构造输入，模型便能在 Zero-Shot 设定下给出合理的输出。 ​\t综上：GPT-1 依赖“预训练+微调”，导致对任务和标注数据高度依赖；GPT-2 则将所有任务视为条件化的语言建模问题，通过设计自然语言的 prompt 与示例在同一模型中复用预训练得到的能力，从而降低对标注数据和微调次数的依赖，实现更好的可扩展性和即时适应能力。其有效性依赖于模型容量与语料覆盖度——当模型和数据足够大时，训练集中自然存在的“任务格式”足以让模型在推理阶段通过提示直接完成许多监督任务。\n2.2 GPT-2数据集 # ​\tGPT-2的文章取自于Reddit上高赞的文章，命名为 WebText ，它是一个超大、超多样化的数据集，包含了从互联网上爬取的数千万个网页，涵盖了新闻、论坛、书籍、代码、问答等各种文体和内容。正是这种 “任务多样性”被隐式地编码在了训练数据中，才使得模型学会了在Zero-Shot设置下响应各种提示。\n​\t注意：GPT-2的论文主要推崇和验证的是Zero-Shot能力。但需要注意的是，GPT-2模型同样可以像GPT-1一样，通过在有标签数据上微调来获得特定任务的极致性能。论文中在部分任务上也给出了微调后的结果作为对比。\n2.3 GPT-2模型参数 # ​\tGPT-2的一个显著特点是其可扩展性。为了探索模型规模与性能的关系，OpenAI发布了四个不同规模的版本。\n模型名称 层数 隐藏层维度 注意力头数 参数量 GPT-2 Small 12 768 12 124 Million GPT-2 Medium 24 1024 16 355 Million GPT-2 Large 36 1280 20 774 Million GPT-2 XL 48 1600 25 ~1.5 Billion 2.4 实验结果 # ​\t在8个语言模型任务中，仅仅通过zero-shot学习，GPT-2就有7个超过了state-of-the-art的方法；\n​\t在“Children\u0026rsquo;s Book Test”数据集上的命名实体识别任务中，GPT-2超过了state-of-the-art的方法约7%；\n​\t“LAMBADA”是测试模型捕捉长期依赖的能力的数据集，GPT-2将困惑度从99.8降到了8.6；\n​\t在阅读理解数据中，GPT-2超过了4个baseline模型中的三个；\n​\t在法译英任务中，GPT-2在zero-shot学习的基础上，超过了大多数的无监督方法，但是比有监督的state-of-the-art模型要差；\n​\tGPT-2在文本总结的表现不理想，但是它的效果也和有监督的模型非常接近。\n2.5 GPT-2与GPT-1对比和区别 # 特性 GPT-1 GPT-2 核心目标 验证“预训练+微调”范式的有效性。 验证大规模无监督预训练能否直接实现Zero-Shot多任务学习。 训练数据 BooksCorpus（约7,000本书，4.6GB）。 WebText（约800万网页，40GB）。数据量更大、来源更广、内容更多样。 模型规模 单一模型，117M 参数。 多个规模的模型，从124M到1.5B，核心是探索Scaling Law。 上下文长度 512 tokens。 1024 tokens。 技术改进 标准的Transformer解码器架构。 移除了微调阶段的辅助LM损失；调整了层归一化的位置；改进了初始化方法。 任务处理方式 任务特定微调。需要为不同任务设计输入变换和添加线性分类头。 任务无关的Zero-Shot。通过提示（Prompt） 来引导模型，无需修改模型架构或参数。 哲学思想 “通过生成式预训练来获得一个强大的、可迁移的特征提取器”。 “一个通用的、任务无关的系统，可以通过条件生成来完成任何任务”。 影响 与BERT一起，确立了预训练+微调作为NLP新范式。 提出了“万物皆可生成”和“Prompt即指令”的雏形，为GPT-3和后来的大语言模型（LLM）革命铺平了道路。 3. GPT-3：上下文学习 # 3.1 核心思想 # ​\t在 GPT-2 中，研究者已经证明了大规模语言模型具备 Zero-Shot 学习能力，即不经过任务特定的微调，仅依赖合适的 Prompt 也能完成部分任务。然而，GPT-2 的参数规模（15 亿）仍然有限，其在复杂任务上的表现不够理想。基于此，OpenAI 在 GPT-3 中进一步提出了“规模定律（Scaling Law）：当模型规模、数据量和计算量不断提升时，模型在语言理解与生成上的能力会持续增长，而无需改变架构。”，并通过构建 1750 亿参数的超大模型，探索了更高层次的通用性。\n​\tGPT-3 的核心思想可以概括为以下几点：\n​\t极大规模化： GPT-3 的最大特点是采用了前所未有的参数规模。研究者发现，当模型规模、数据量和算力不断提升时，模型的语言理解与生成能力会持续增强。这一发现表明，单纯依靠规模扩展而不改变模型架构，依旧能够显著提升模型的泛化性能。\n​\t上下文学习（In-Context Learning）： GPT-3 在推理阶段展现出了一种新的学习方式，即通过输入中的上下文提示来完成任务，而无需更新参数。例如，在文本分类任务中，只需在输入中提供少量“输入—输出”的示例，模型便能够类比学习并生成符合要求的结果。这表明 GPT-3 能够将学习过程转移到推理阶段的上下文中。\n​\t**元学习（Meta-Learning）：**在上下文学习的基础上，GPT-3 展现出类似 元学习 的特征。传统机器学习在面对新任务时需要显式训练，而 GPT-3 仅依赖少量示例就能快速适配任务。这种能力表现为“学习如何学习”，意味着模型能够通过 Prompt 灵活切换任务，而不需要参数层面的更新。\n​\t**Prompt 驱动范式的统一：**GPT-3 将不同任务统一为 条件文本生成问题，差异仅体现在 Prompt 的设计上：\n在 Zero-Shot 模式下，直接使用自然语言指令； 在 One-Shot 模式下，在输入中加入一个示例； 在 Few-Shot 模式下，输入多个示例（通常为10-100个），模拟训练过程。 ​\t这种方式使模型能够在无需针对性微调的前提下，完成多种自然语言处理任务。下图展示了GPT中Zero-Shot、One-Shot、Few-Shot以及Fine-tuning的区别。\n​\t综上：GPT-3 的核心贡献在于，它不仅通过极大规模化验证了“规模定律”，证明了更大模型能够显著增强任务泛化能力，而且提出了 上下文学习 的新范式，让模型能够在推理过程中直接学习任务模式。\n​\t此外，GPT-3 展现出的 元学习能力 表明，它可以通过少量示例快速适应新任务，具备了初步的“学习如何学习”的特征。这使得 GPT-3 不再仅仅是一个语言建模工具，而是演变为一个灵活的任务适配器。\n​\t最后，GPT-3 通过 Zero-Shot、One-Shot 与 Few-Shot 等 Prompting 方式，统一了不同任务的执行逻辑，真正实现了**“任务即语言建模”**。因此，GPT-3 相较于 GPT-2，不仅在规模上实现了跨越，更在认知能力层面迈出了通向通用人工智能的重要一步。\n3.2 GPT-3数据集 # ​\tGPT-3使用了比GPT-2的WebText更大、更多样化的混合数据集。其来源和比例如下：\n数据集来源 权重 描述 Common Crawl (过滤后) 60% 核心数据源，进行了严格的质量过滤（基于与高质量语料的相似度）和去重。 WebText2 (GPT-2的数据扩展版) 22% 高质量的内部数据集。 Books1 8% 两个书籍语料库之一。 Books2 8% 两个书籍语料库之一，质量更高。 Wikipedia 3% 仅英文维基百科的文本部分。 ​\t总训练数据量约570GB的纯文本，包含近万亿个单词。\n3.3 GPT-3模型参数 # ​\tGPT-3的架构与GPT-2基本相同，仍然是仅解码器的Transformer模型。其所有的改进几乎都来自于 “放大”。\n参数规模： 1750亿（175 Billion）参数。这比之前最大的稠密模型（如Turing-NLG的17B）大了一个数量级。 模型尺寸： 为了达到175B参数，OpenAI训练了8种不同规模的模型来研究缩放定律，其中最大的GPT-3配置如下： 层数： 96 注意力头数： 96 隐藏层维度： 12288 批次大小： 3.2M个token（动态调整） 上下文窗口： 2048个token（比GPT-2又翻倍） 关键技术细节： 为了在如此大的模型下节省内存，采用了模型并行技术。 在自注意力层中使用了交替的稠密和局部带状稀疏注意力模式，以在某些层中高效处理长序列 3.4 实验结果 # ​\t仅仅用惊艳很难描述GPT-3的优秀表现。首先，在大量的语言模型数据集中，GPT-3超过了绝大多数的zero-shot或者few-shot的state-of-the-art方法。另外GPT-3在很多复杂的NLP任务中也超过了fine-tune之后的state-of-the-art方法，例如闭卷问答，模式解析，机器翻译等。除了这些传统的NLP任务，GPT-3在一些其他的领域也取得了非常震惊的效果，例如进行数学加法，文章生成，编写代码等。\n3.5 GPT-2与GPT-3的对比 # 特性 GPT-2 GPT-3 核心目标 验证Zero-Shot多任务学习的可行性。 验证超大规模模型下的Few-Shot In-Context Learning能否达到或超越微调模型的水平。 关键能力 Zero-Shot Few-Shot, One-Shot, Zero-Shot，且Few-Shot是其主要优势。 哲学思想 “一个通用的任务无关系统”。 “一个通过预训练完成了元学习的通用系统，可通过上下文快速适应新任务”。 模型规模 最大 1.5B 参数。 175B 参数，两个数量级的差距。 训练数据 WebText (~40GB)。 混合高质量数据集，~570GB，规模更大、清洗更严格。 上下文长度 1024 tokens。 2048 tokens。 计算成本 相对较低。 极其昂贵，训练一次需数千PetaFLOPs-day的计算量。 评估方式 主要评估Zero-Shot，并与微调基线对比。 系统性地评估Few-Shot, One-Shot, Zero-Shot，并与微调的SOTA模型直接竞争。 涌现能力 展示了有希望的Zero-Shot潜力。 展示了惊人的涌现能力，如进行数学运算、编写复杂的代码、理解抽象概念等，这些能力在较小模型上几乎不存在。 影响与局限 为LLM和Prompting思想铺平了道路。 真正开启了大语言模型时代。证明了“缩放定律”的惊人潜力，同时也暴露了模型的局限性（如事实幻觉、重复训练数据偏见等）。 ​\t总结：GPT-3不是一次算法上的革命，而是一次工程和理念上的极致探索。它将以GPT为代表的“生成式预训练”道路推向了当时的顶峰，并雄辩地证明：\n规模本身就是一种能力：当模型大到一定程度时，会涌现出小模型不具备的In-Context Learning等高级能力。 范式转移：对于许多NLP任务，收集大量标注数据并微调模型的传统范式，可能不再是唯一的最佳路径。通过精心设计的Prompt，直接利用大模型的内部知识成为可能。 ​\tGPT-3的成功直接催生了后来的Codex、InstructGPT和ChatGPT，奠定了当前AI浪潮的基础。\n4. 总结 # ​\tGPT系列从1到3，通通采用的是Transformer架构，可以说模型结构并没有创新性的设计。在微软的资金支持下，这更像是一场赤裸裸的炫富：1750亿的参数，31个分工明确的作者，超强算力的计算机（ 285000个CPU， 10000个GPU），1200万的训练费用，45TB的训练数据（维基百科的全部数据只相当于其中的0.6% ）。这种规模的模型是一般中小企业无法承受的，而个人花费巨金配置的单卡机器也就只能做做微调或者打打游戏了。甚至在训练GPT-3时出现了一个bug，OpenAI自己也没有资金重新训练了。\n​\t读懂了GPT-3的原理，相信我们就能客观的看待媒体上对GPT-3的过分神话了。GPT-3的本质还是通过海量的参数学习海量的数据，然后依赖Transformer强大的拟合能力使得模型能够收敛。基于这个原因，GPT-3学到的模型分布也很难摆脱这个数据集的分布情况。得益于庞大的数据集，GPT-3可以完成一些令人感到惊喜的任务，但是GPT-3也不是万能的，对于一些明显不在这个分布或者和这个分布有冲突的任务来说，GPT-3还是无能为力的。例如通过目前的测试来看，GPT-3还有很多缺点的:\n对于一些命题没有意义的问题，GPT-3不会判断命题有效与否，而是拟合一个没有意义的答案出来； 由于40TB海量数据的存在，很难保证GPT-3生成的文章不包含一些非常敏感的内容，例如种族歧视，性别歧视，宗教偏见等； 受限于transformer的建模能力，GPT-3并不能保证生成的一篇长文章或者一本书籍的连贯性，存在下文不停重复上文的问题。 ​\tOpenAI的CEO也发Twitter说“The GPT-3 hype is way too much. It\u0026rsquo;s impressive (thanks for the nice compliments!) but it still has serious weaknesses and sometimes makes very silly mistakes. AI is going to change the world, but GPT-3 is just a very early gimpse. We have a lot still to figure out.”\n​\tGPT-3对AI领域的影响无疑是深远的，如此强大性能的语言模型的提出，为下游各种类型的NLP任务提供了非常优秀的词向量模型，在此基础上必将落地更多有趣的AI应用。近年来，硬件的性能在飞速发展，而算法的研究似乎遇见了瓶颈，GPT-3给冷清的AI领域注入了一剂强心剂，告诉各大硬件厂商它们的工作还要加油，只要算力足够强，AI的性能还有不断提升的上界。\n​\t同时GPT-3如此高昂的计算代价也引发了一些关于AI领域垄断的一些担心，对于如此高的算力要求，中小企业是否有能力负担的起，或者对于这些企业来说，是否有必要花这么多钱就训练一个词向量模型。长此以往，恐怕会形成AI巨头对算力要求高的算法的技术垄断。\n参考：\nGPT，GPT-2，GPT-3 论文精读\nGPT：无标注数据的预训练生成式语言模型\nGPT-2：GPT 在零样本多任务学习的探索\nGPT-3：大型语言模型是少样本学习器\n预训练语言模型之GPT-1，GPT-2和GPT-3\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/gpt/","section":"文章","summary":"","title":"GPT 系列","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E9%A2%84%E8%AE%AD%E7%BB%83%E6%A8%A1%E5%9E%8B/","section":"标签","summary":"","title":"预训练模型","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/bert/","section":"标签","summary":"","title":"BERT","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/transformer/","section":"标签","summary":"","title":"Transformer","type":"tags"},{"content":" 1. Transformer # 1.1 背景 # ​\t传统序列建模主要依赖 RNN/LSTM 或 CNN，但这类方法存在无法并行、长距离依赖难学习的问题。\n​\tTransformer的创新点是完全舍弃循环和卷积，完全基于注意力机制来建模序列关系。对比如下图所示： 1.2 模型核心思想 # ​\t模型结构如下图所示： ​\t自注意力 (Self-Attention)：序列中的每个位置都可以直接关注到任意位置，不受距离限制。\n​\t缩放点积注意力 (Scaled Dot-Product Attention)：通过 Query-Key-Value 机制计算注意力权重。\n​\t多头注意力 (Multi-Head Attention)：多个子空间并行关注不同的信息，提高模型表达能力。\n​\t位置编码 (Positional Encoding)：因为没有卷积和循环，所以通过正弦/余弦函数引入位置信息。\n​\t编码器-解码器结构：和传统 Seq2Seq 类似，但内部完全由注意力 + 前馈网络构成。\n1.3 实验结果 # ​\t在 WMT 2014 英德翻译任务上，Transformer (big) 模型取得 28.4 BLEU，比当时最好的结果高出 2 BLEU。\n​\t在英法翻译任务上，达到 41.8 BLEU，刷新单模型最优结果。\n​\t训练速度更快：仅用 8 块 P100 GPU 训练 3.5 天，比之前的 RNN/CNN 模型训练成本低得多。\n1.4 优势总结 # ​\t高并行度：相比 RNN 串行计算，Transformer 可以大幅并行化。\n​\t长距离建模能力强：任何位置之间只需常数步数即可建立依赖关系。\n​\t更好的效果 \u0026amp; 更低的训练成本：既提高了 BLEU 分数，又缩短了训练时间。\n1.5 代码 # # -*- codeing = utf-8 -*- import math import pandas as pd import torch from torch import nn from d2l import torch as d2l from transformer_utils import Encoder, MultiHeadAttention, PositionalEncoding, AttentionDecoder, EncoderDecoder, train_seq2seq, predict_seq2seq, bleu # 基于位置的前馈神经网络 class PositionWiseFFN(nn.Module): \u0026#34;\u0026#34;\u0026#34;基于位置的前馈网络\u0026#34;\u0026#34;\u0026#34; def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs): super(PositionWiseFFN, self).__init__(**kwargs) self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens) self.relu = nn.ReLU() self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs) def forward(self, X): return self.dense2(self.relu(self.dense1(X))) # 使用残差连接和层规范化 class AddNorm(nn.Module): \u0026#34;\u0026#34;\u0026#34;残差连接后进行层规范化\u0026#34;\u0026#34;\u0026#34; def __init__(self, normalized_shape, dropout, **kwargs): super(AddNorm, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) self.ln = nn.LayerNorm(normalized_shape) def forward(self, X, Y): return self.ln(self.dropout(Y) + X) # 实现编码器中的一个层 class EncoderBlock(nn.Module): \u0026#34;\u0026#34;\u0026#34;Transformer编码器块\u0026#34;\u0026#34;\u0026#34; def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs): super(EncoderBlock, self).__init__(**kwargs) self.attention = MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias) self.addnorm1 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN( ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm2 = AddNorm(norm_shape, dropout) def forward(self, X, valid_lens): Y = self.addnorm1(X, self.attention(X, X, X, valid_lens)) return self.addnorm2(Y, self.ffn(Y)) # Transformer编码器 class TransformerEncoder(Encoder): \u0026#34;\u0026#34;\u0026#34;Transformer编码器\u0026#34;\u0026#34;\u0026#34; def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, use_bias=False, **kwargs): super(TransformerEncoder, self).__init__(**kwargs) self.num_hiddens = num_hiddens self.embedding = nn.Embedding(vocab_size, num_hiddens) self.pos_encoding = PositionalEncoding(num_hiddens, dropout) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module(\u0026#34;block\u0026#34;+str(i), EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)) def forward(self, X, valid_lens, *args): X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens)) self.attention_weights = [None] * len(self.blks) for i, blk in enumerate(self.blks): X = blk(X, valid_lens) self.attention_weights[i] = blk.attention.attention.attention_weights return X # Transformer解码器也是由多个相同的层组成 class DecoderBlock(nn.Module): \u0026#34;\u0026#34;\u0026#34;解码器中第i个块\u0026#34;\u0026#34;\u0026#34; def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i, **kwargs): super(DecoderBlock, self).__init__(**kwargs) self.i = i self.attention1 = MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm1 = AddNorm(norm_shape, dropout) self.attention2 = MultiHeadAttention( key_size, query_size, value_size, num_hiddens, num_heads, dropout) self.addnorm2 = AddNorm(norm_shape, dropout) self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens) self.addnorm3 = AddNorm(norm_shape, dropout) def forward(self, X, state): enc_outputs, enc_valid_lens = state[0], state[1] if state[2][self.i] is None: # 开始时为空，K、V为当前的输入X key_values = X else: # 后面随着预测会产生新的K、V，需要将其与之前的X concat在一起（预测阶段） key_values = torch.cat((state[2][self.i], X), axis=1) state[2][self.i] = key_values # 最终state[2][self.i]存储的当前计算时所有的K、V if self.training: batch_size, num_steps, _ = X.shape # 如果在训练阶段，需要使用masked_multi_head attention dec_valid_lens = torch.arange( 1, num_steps + 1, device=X.device).repeat(batch_size, 1) else: # 预测阶段则不需要使用masked_multi_head attention dec_valid_lens = None X2 = self.attention1(X, key_values, key_values, dec_valid_lens) # 注意此处为自注意力机制，K和V由编码器输入以及之后的预测提供，此处dec_valid_lens表示在训练阶段在预测第t个词的时候只看前面t-1个词 Y = self.addnorm1(X, X2) Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens) # 此处为普通的注意力机制，其K和V由解码器的输出提供，此处enc_valid_lens表示最终会去除每个句子中的pad Z = self.addnorm2(Y, Y2) return self.addnorm3(Z, self.ffn(Z)), state # Transformer解码器 class TransformerDecoder(AttentionDecoder): def __init__(self, vocab_size, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, **kwargs): super(TransformerDecoder, self).__init__(**kwargs) self.num_hiddens = num_hiddens self.num_layers = num_layers self.embedding = nn.Embedding(vocab_size, num_hiddens) self.pos_encoding = PositionalEncoding(num_hiddens, dropout) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module(\u0026#34;block\u0026#34;+str(i), DecoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, i)) self.dense = nn.Linear(num_hiddens, vocab_size) def init_state(self, enc_outputs, enc_valid_lens, *args): return [enc_outputs, enc_valid_lens, [None] * self.num_layers] def forward(self, X, state): X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens)) self._attention_weights = [[None] * len(self.blks) for _ in range (2)] for i, blk in enumerate(self.blks): X, state = blk(X, state) self._attention_weights[0][i] = blk.attention1.attention.attention_weights self._attention_weights[1][i] = blk.attention2.attention.attention_weights return self.dense(X), state @property def attention_weights(self): return self._attention_weights 详细代码见本地 d2l-zh-pytorch/chapter_attention_mechanisms。\n2. BERT # 2.1 背景 # ​\t传统 NLP 模型（RNN、CNN、甚至早期 Transformer）大多是任务专用的，需要针对不同任务单独训练。\n​\t当时的预训练方法（如 word2vec、ELMo）要么是静态词向量，要么是单向语言模型（GPT），不能充分利用上下文信息。\n​\tBERT 提出了一种基于 Transformer 的双向预训练模型，大幅提升 NLP 各类任务的效果。\n2.2 模型核心思想 # ​\t基于 Transformer Encoder：不同于 Transformer 的 Encoder-Decoder 结构，BERT 仅使用 Encoder 堆叠。\n​\tBERTBASE：num_blocks=12, num_hiddens=768, head=12 总参数量：110M\n​\tBERTLARGE num_blocks=24, num_hiddens=1024, head=16 总参数量：340M\n​\t双向上下文建模：相比 GPT（单向）、ELMo（前后两向拼接），BERT 通过 Masked LM 训练，实现真正的双向理解。\n​\t模型网络结构如下所示： 2.3 两个关键预训练任务 # 2.3.1 Masked Language Model (MLM) # ​\t(1) 训练思路：从输入句子中随机挑选 15% 的词，作为预测目标。\n​\t(2) 对这些词的处理方式：\n​\t80% 的概率替换成 [MASK]（如 I love dogs → I [MASK] dogs）。\n​\t10% 的概率替换成 随机词（如 I love dogs → I apple dogs）。\n​\t10% 的概率保持原词不变（但仍然需要预测）。\n​\t模型需要根据上下文预测被遮蔽位置的词。\n​\t(3) 训练数据\n​\tWikipedia（英文维基百科，约 25 亿词）\n​\tBookCorpus（约 8 亿词）\n​\t输入是原始文本（无人工标注，自监督），只需随机遮掩词即可构造训练样本。让模型学习到双向语境的表示。\n2.3.2 Next Sentence Prediction (NSP) # ​\t(1) 训练思路：\n​\t输入由两个句子（A, B）组成，模型需要预测 B 是否是 A 的真实后续句子。\n​\t(2) 构造方法：\n​\t50% 的情况：B 是 A 在语料中的真实下一句（标记为 IsNext）。\n​\t50% 的情况：B 是从语料中随机采样的句子（标记为 NotNext）。\n​\t模型通过 [CLS] 位置的向量做二分类任务。\n​\t(3) 训练数据：\n​\t同样来自 Wikipedia + BookCorpus。\n​\t不需要额外人工标注，只需从原始语料中拼接句子对，就能自动生成正负样本。二分类标号（IsNext / NotNext）。\n​\t让模型具备理解句子关系的能力。\n2.4 实验结果 # ​\t在 11 个 NLP 任务（GLUE、SQuAD、SWAG 等）上，BERT 大幅刷新了 SOTA。\n​\tSQuAD 1.1 问答任务上，BERT 超过了人类基线。\n​\tBERT 也推动了后续一系列模型的发展（RoBERTa、ALBERT、ERNIE 等）。\n2.5 优势总结 # ​\t统一框架：只需预训练一次，微调时稍加修改即可应用于不同下游任务。\n​\t强大的效果：显著提升了文本分类、问答、自然语言推理等任务的性能。\n​\t里程碑意义：BERT 开启了预训练语言模型时代，对 NLP 影响深远。\n2.6 代码 # import torch from torch import nn from bert_utils import EncoderBlock class BERTEncoder(nn.Module): \u0026#34;\u0026#34;\u0026#34;BERT编码器\u0026#34;\u0026#34;\u0026#34; def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, use_bias=True): super(BERTEncoder, self).__init__() self.token_embedding = nn.Embedding(vocab_size, num_hiddens) self.segment_embedding = nn.Embedding(2, num_hiddens) # 在BERT中，位置嵌入是可学习的，因此我们创建一个足够长的位置嵌入参数 self.pos_embedding = nn.Parameter(torch.randn(size=(1, max_len, num_hiddens))) self.blks = nn.Sequential() for i in range(num_layers): self.blks.add_module(f\u0026#39;{i}\u0026#39;, EncoderBlock(key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias)) def forward(self, tokens, segments, valid_lens): # 在以下代码段中，X的形状保持不变：（批量大小，最大序列长度，num_hiddens） X = self.token_embedding(tokens) + self.segment_embedding(segments) X += self.pos_embedding.data[:, :X.shape[1], :] for blk in self.blks: X = blk(X, valid_lens) return X class MaskLM(nn.Module): \u0026#34;\u0026#34;\u0026#34;BERT的掩蔽语言模型任务\u0026#34;\u0026#34;\u0026#34; def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs): super(MaskLM, self).__init__() self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens), nn.ReLU(), nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, vocab_size)) def forward(self, X, pred_positions): num_pred_positions = pred_positions.shape[1] pred_positions_id = pred_positions.reshape(-1) batch_size = X.shape[0] batch_id = torch.arange(0, batch_size) batch_idx = torch.repeat_interleave(batch_id, num_pred_positions) # 假设batch_size=2，num_pred_positions=3 # 那么batch_idx是np.array（[0,0,0,1,1,1]） masked_X = X[batch_idx, pred_positions_id] masked_X = masked_X.reshape((batch_size, num_pred_positions, -1)) mlm_Y_hat = self.mlp(masked_X) return mlm_Y_hat class NextSentencePred(nn.Module): \u0026#34;\u0026#34;\u0026#34;BERT的下一句预测任务\u0026#34;\u0026#34;\u0026#34; def __init__(self, num_inputs): super(NextSentencePred, self).__init__() self.output = nn.Linear(num_inputs, 2) def forward(self, X): # X的形状：(batchsize,num_hiddens) return self.output(X) class BERTModel(nn.Module): \u0026#34;\u0026#34;\u0026#34;BERT模型\u0026#34;\u0026#34;\u0026#34; def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len=1000, key_size=768, query_size=768, value_size=768, use_bias=True, hid_in_features=768, mlm_in_features=768, nsp_in_features=768): super(BERTModel, self).__init__() self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout, max_len, key_size, query_size, value_size, use_bias) self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features) self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens), nn.Tanh()) self.nsp = NextSentencePred(nsp_in_features) def forward(self, tokens, segments, valid_lens=None, pred_positions=None): encoder_X = self.encoder(tokens, segments, valid_lens) if pred_positions is not None: mlm_Y_hat = self.mlm(encoder_X, pred_positions) else: mlm_Y_hat = None # 用于下一句预测的多层感知机分类器的隐藏层，0是“\u0026lt;cls\u0026gt;”标记的索引 nsp_Y_hat = self.nsp(self.hidden(encoder_X[:, 0, :])) return encoder_X, mlm_Y_hat, nsp_Y_hat if __name__ == \u0026#39;__main__\u0026#39;: vocab_size, num_hiddens, ffn_num_input, ffn_num_hiddens, num_heads, num_layers = \\ (1000, 768, 768, 1024, 4, 2) norm_shape, dropout = [768], 0.2 encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, num_layers, dropout) tokens = torch.randint(0, vocab_size, (2, 8)) segments = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1, 1, 1]]) enc_outputs = encoder(tokens, segments, None) print(enc_outputs.shape) mlm = MaskLM(vocab_size, num_hiddens) pred_positions = torch.tensor([[1, 5, 2], [6, 1, 5]]) mlm_Y_hat = mlm(enc_outputs, pred_positions) print(mlm_Y_hat.shape) mlm_Y = torch.tensor([[7, 8, 9], [10, 11, 12]]) loss = nn.CrossEntropyLoss(reduction=\u0026#39;none\u0026#39;) mlm_loss = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1)) print(\u0026#39;mlm_loss:\u0026#39;, mlm_loss, \u0026#39;\\n shape:\u0026#39;, mlm_loss.shape) nsp = NextSentencePred(enc_outputs.shape[-1]) # NSP的输入形状:(batch_size, num_hiddens) nsp_Y_hat = nsp(enc_outputs[:, 0, :]) # 只把\u0026lt;cls\u0026gt;（每个序列第一个词元的特征维度）的特征维度输入nsp中就行 print(\u0026#39;nsp_Y_hat:\u0026#39;, nsp_Y_hat, \u0026#39;\\nnsp_Y_hat_shape:\u0026#39;, nsp_Y_hat.shape) nsp_Y = torch.tensor([0, 1]) nsp_loss = loss(nsp_Y_hat, nsp_Y) print(\u0026#39;nsp_loss:\u0026#39;, nsp_loss, \u0026#39;\\nnsp_loss_shape:\u0026#39;, nsp_loss.shape) 详细代码见本地 d2l-zh-pytorch/chapter_natural_language_processing。\n2.6 BERT用于下游任务 # 2.6.1 句子级别任务 # 这类任务通常使用 [CLS] token 的向量表示整个句子的语义信息。\n（1）句子分类任务（如：情感分析、自然语言推理）\n输入格式：\n[CLS] Sentence A [SEP] Sentence B [SEP] 对单句任务：Sentence A 是目标句子。 对句对任务：Sentence A、Sentence B 是两个句子。 使用方式：\n取 [CLS] token 的最终隐藏向量（h_CLS）作为句子的整体表示。 输入一个全连接层（+Softmax）得到分类结果。 对应任务示例：\nMNLI（多领域自然语言推理） QQP（句子是否语义相同） SST-2（情感分类） QNLI（问句与句子是否相关） 2.6.2 词语级别任务 # 这类任务关注句子中每个词的输出表示。\n（2）序列标注任务（如：命名实体识别 NER）\n输入格式：\n[CLS] word1 word2 ... wordN [SEP] 使用方式：\n对每个词的最后一层隐藏向量 hih_ihi​，添加一个线性层（Softmax）预测标签。 常见输出是 BIO/BILOU 标签形式。 对应任务示例：\nCoNLL-2003 NER 2.6.3 句对关系任务 # 这类任务需要判断两个句子之间的关系，使用 NSP 预训练任务 的能力。\n（3）自然语言推理 / 句子关系判断\n输入格式： [CLS] Sentence A [SEP] Sentence B [SEP] BERT 会根据两个句子的上下文交互表示来判断它们的语义关系（例如“蕴含”、“矛盾”、“无关”）。\n输出层与句子分类任务相同（用 [CLS] 的输出向量做分类）。\n对应任务示例：\nMNLI, QNLI, RTE 2.6.4 机器问答任务 # （4）SQuAD 问答任务\n输入格式：\n[CLS] Question [SEP] Paragraph [SEP] 模型输出：\n对段落中的每个词，预测它是答案的起始位置或结束位置。 使用两个独立的线性层： Start layer → 每个词的 start probability End layer → 每个词的 end probability 选取概率最高的 (start, end) 作为答案范围。 对应任务示例：\nSQuAD v1.1 / v2.0 参考：\nTransformer\nBERT预训练\nTransformer论文逐段精读\nBERT 论文逐段精读\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/transformer-bert/","section":"文章","summary":"","title":"Transformer 与 BERT","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E6%B3%A8%E6%84%8F%E5%8A%9B%E6%9C%BA%E5%88%B6/","section":"标签","summary":"","title":"注意力机制","type":"tags"},{"content":" 1. 损失函数 # 1.1 均方误差损失（Mean Squared Error, MSE） # 公式： LMSE=1n∑i=1n(yi−y^i)2 L_{MSE} = \\frac{1}{n} \\sum_{i=1}^{n} (y_i - \\hat{y}_i)^2 LMSE​=n1​i=1∑n​(yi​−y^​i​)2 其中： yiy_iyi​ ：真实值（ground truth） y^i\\hat{y}_iy^​i​ ：模型预测值 nnn ：样本数量 解释：\nMSE度量的是预测值与真实值之间的 欧氏距离的平方，常用于回归任务。 直观上，它让预测结果尽可能接近真实值，误差越大，损失惩罚越大（因为平方项会放大偏差）。 它最小化的目标：\n最小化预测值与真实值的 均方差，等价于让预测的期望值趋近于真实值。 数学上，MSE 的最小化逼近的是条件期望 E[y∣x]E[y|x]E[y∣x]。 1.2 交叉熵损失（Cross-Entropy Loss） # 公式（多分类情况）：\nLCE=−∑i=1Cyilog⁡(y^i) L_{CE} = - \\sum_{i=1}^{C} y_i \\log(\\hat{y}_i) LCE​=−i=1∑C​yi​log(y^​i​) 其中： CCC ：类别数 yiy_iyi​ ：真实标签的 one-hot 编码（正确类别为 1，其余为 0） y^i\\hat{y}_iy^​i​ ：模型预测的该类别概率（通常由 softmax 输出） 解释： 交叉熵度量了 真实分布与预测分布之间的差异。 如果预测概率完全与真实一致（即正确类别概率为 1），交叉熵损失为 0。\n它最小化的目标： 最小化真实分布与预测分布的 KL 散度：\nDKL(P∥Q)=∑iP(i)log⁡P(i)Q(i) D_{KL}(P \\parallel Q) = \\sum_i P(i) \\log \\frac{P(i)}{Q(i)} DKL​(P∥Q)=i∑​P(i)logQ(i)P(i)​ 1.3 二元交叉熵损失 # 设：\ny∈{0,1}y \\in \\{0,1\\}y∈{0,1} ：真实标签（1 表示正类，0 表示负类） y^∈[0,1]\\hat{y} \\in [0,1]y^​∈[0,1] ：预测的正类概率（通常由 sigmoid 输出） 公式：\nLBCE=−[y⋅log⁡(y^)+(1−y)⋅log⁡(1−y^)] L_{BCE} = - \\big[ y \\cdot \\log(\\hat{y}) + (1-y) \\cdot \\log(1-\\hat{y}) \\big] LBCE​=−[y⋅log(y^​)+(1−y)⋅log(1−y^​)]解释：\n当 y=1y=1y=1（正类）时： L=−log⁡(y^) L = -\\log(\\hat{y}) L=−log(y^​) 预测概率 y^\\hat{y}y^​ 越接近 1，损失越小。\n当 y=0y=0y=0（负类）时： L=−log⁡(1−y^) L = -\\log(1-\\hat{y}) L=−log(1−y^​) 预测概率 y^\\hat{y}y^​ 越接近 0，损失越小。\n多分类交叉熵： L=−∑i=1Cyilog⁡(y^i) L = - \\sum_{i=1}^{C} y_i \\log(\\hat{y}_i) L=−i=1∑C​yi​log(y^​i​) 当 C=2C=2C=2 时，softmax 的输出为 (y^,1−y^)(\\hat{y}, 1-\\hat{y})(y^​,1−y^​)，代入上式后可以化简为： L=−[ylog⁡(y^)+(1−y)log⁡(1−y^)] L = - \\big[ y \\log(\\hat{y}) + (1-y) \\log(1-\\hat{y}) \\big] L=−[ylog(y^​)+(1−y)log(1−y^​)] 结论：二分类交叉熵是多分类交叉熵在 C=2C=2C=2 时的特例。\n1.4 损失函数总结对比 # MSE\n用于回归任务 最小化预测值与真实值之间的平方误差 收敛到条件期望 交叉熵\n用于分类任务 最小化预测分布与真实分布之间的差异 收敛到真实概率分布 1.5 损失函数与目标函数 # ​\t在深度学习论文中，经常会出现 损失函数（Loss Function） 和 目标函数（Objective Function） 两个概念。两者既有联系，也有区别。\n1.5.1 损失函数（Loss Function） # 定义： 针对 单一样本，衡量模型预测与真实值之间差异的函数。\n作用： 描述模型在一个样本上的“错误程度”。\n常见例子：\n回归：均方误差（MSE）\nL(x,y)=(y−y^)2 L(x,y) = (y - \\hat{y})^2 L(x,y)=(y−y^​)2 分类：交叉熵损失（CE）\nL(x,y)=−∑iyilog⁡(y^i) L(x,y) = - \\sum_i y_i \\log(\\hat{y}_i) L(x,y)=−i∑​yi​log(y^​i​) 1.5.2 目标函数（Objective Function） # 定义： 在整个数据集上，模型最终要 优化/最小化/最大化 的函数。\n关系： 通常是对所有样本的损失函数取平均（或加总），有时还会加上正则化项。\n形式： J(θ)=1N∑i=1NL(xi,yi;θ)+λR(θ) J(\\theta) = \\frac{1}{N} \\sum_{i=1}^N L(x_i, y_i; \\theta) + \\lambda R(\\theta) J(θ)=N1​i=1∑N​L(xi​,yi​;θ)+λR(θ)其中：\nR(θ)R(\\theta)R(θ) ：正则化项 λ\\lambdaλ ：正则化系数 👉 损失函数是“局部”的，目标函数是“全局”的。\n1.5.3 在概率建模（如 GPT）中的情况 # 在 GPT 等生成模型论文中，常说 目标函数是最大化似然函数（MLE）。\nGPT 建模条件概率： Pθ(x)=∏t=1TPθ(xt∣x\u0026lt;t) P_\\theta(x) = \\prod_{t=1}^T P_\\theta(x_t \\mid x_{\u0026lt;t}) Pθ​(x)=t=1∏T​Pθ​(xt​∣x\u0026lt;t​) 最大化似然：让训练数据的概率尽可能大\nJ(θ)=max⁡θ∑t=1Tlog⁡Pθ(xt∣x\u0026lt;t) J(\\theta) = \\max_\\theta \\sum_{t=1}^T \\log P_\\theta(x_t \\mid x_{\u0026lt;t}) J(θ)=θmax​t=1∑T​logPθ​(xt​∣x\u0026lt;t​) 对应的 损失函数 是负对数似然（NLL）：\nL(x;θ)=−∑t=1Tlog⁡Pθ(xt∣x\u0026lt;t) L(x; \\theta) = - \\sum_{t=1}^T \\log P_\\theta(x_t \\mid x_{\u0026lt;t}) L(x;θ)=−t=1∑T​logPθ​(xt​∣x\u0026lt;t​) 整体目标函数就是最小化 NLL 的平均：\nJ(θ)=min⁡θ1N∑i=1NL(xi;θ) J(\\theta) = \\min_\\theta \\frac{1}{N} \\sum_{i=1}^N L(x_i; \\theta) J(θ)=θmin​N1​i=1∑N​L(xi​;θ) 👉 最大化似然函数 ↔ 最小化负对数似然损失。\n1.5.4 总结对比 # 概念 损失函数 (Loss) 目标函数 (Objective) 作用 衡量单个样本预测的误差 衡量整体模型好坏 范围 针对单一样本 针对整个数据集 数学形式 L(x,y)L(x,y)L(x,y) J(θ)=∑L(x,y)J(\\theta) = \\sum L(x,y)J(θ)=∑L(x,y) 或加正则化 在 GPT 中 单个 token 的 NLL 整个序列的似然（或其平均 NLL） 2. KL散度 # 2.1 熵的定义 # **熵（Entropy）**是信息论中的核心概念，用来衡量一个随机变量的不确定性或平均信息量。\n**离散情况：**设随机变量 XXX 具有概率质量函数（PMF） P(x)P(x)P(x)，则熵定义为： H(P)=−∑xP(x)log⁡P(x) H(P) = - \\sum_{x} P(x) \\log P(x) H(P)=−x∑​P(x)logP(x) 这里的 P(x)P(x)P(x) 表示随机变量取值 xxx 的概率。 熵越大，表示分布越不确定。 **连续情况：**设随机变量 XXX 的概率密度函数（PDF）为 p(x)p(x)p(x)，则熵的定义为： H(P)=−∫p(x)log⁡p(x) dx H(P) = - \\int p(x) \\log p(x) \\, dx H(P)=−∫p(x)logp(x)dx 这时的熵被称为 微分熵（Differential Entropy）。 注意：微分熵可能为负数，不同于离散情形总是非负。 直观理解\n如果一个事件概率很大（确定性强），其信息量小。 如果一个事件概率很小（不确定性强），其信息量大。 熵是分布的 平均信息量。 例子：\n公平抛硬币： P(正)=0.5, P(反)=0.5 P(正)=0.5,\\; P(反)=0.5 P(正)=0.5,P(反)=0.5 熵为： H(P)=−(0.5log⁡0.5+0.5log⁡0.5)=1 比特 H(P) = - (0.5 \\log 0.5 + 0.5 \\log 0.5) = 1 \\ \\text{比特} H(P)=−(0.5log0.5+0.5log0.5)=1 比特 偏置硬币： P(正)=0.9, P(反)=0.1 P(正)=0.9,\\; P(反)=0.1 P(正)=0.9,P(反)=0.1 熵较小，因为结果更可预测。\n2.2 KL散度的定义 # ​\tKL散度（Kullback-Leibler divergence），可以以称作相对熵（relative entropy）或信息散度（information divergence）。KL散度的理论意义在于度量两个概率分布之间的差异程度，当KL散度越大的时候，说明两者的差异程度越大；而当KL散度小的时候，则说明两者的差异程度小。如果两者相同的话，则该KL散度应该为0。\n设：\nP(x)P(x)P(x) ：真实分布（目标分布） Q(x)Q(x)Q(x) ：模型分布（近似分布） 如果我们用Q(x)Q(x)Q(x)去近似P(x)P(x)P(x)，则KL散度可以表示为：\n离散情况 DKL(P∥Q)=∑xP(x)log⁡P(x)Q(x) D_{KL}(P \\parallel Q) = \\sum_x P(x) \\log \\frac{P(x)}{Q(x)} DKL​(P∥Q)=x∑​P(x)logQ(x)P(x)​连续情况 DKL(P∥Q)=∫P(x)log⁡P(x)Q(x) dx D_{KL}(P \\parallel Q) = \\int P(x) \\log \\frac{P(x)}{Q(x)} \\, dx DKL​(P∥Q)=∫P(x)logQ(x)P(x)​dx从上面的公式可以看出：\nKL散度具有非负性：DKL(P∥Q)≥0D_{KL}(P \\parallel Q) \\geq 0DKL​(P∥Q)≥0\n当且仅当 P(x)=Q(x)P(x) = Q(x)P(x)=Q(x) 对所有 xxx 成立时，DKL(P∥Q)=0D_{KL}(P \\parallel Q) = 0DKL​(P∥Q)=0\nKL散度不具备对称性，也就是说PPP对于QQQ的KL散度并不等于QQQ对于PPP的KL散度： DKL(P∥Q)≠DKL(Q∥P) D_{KL}(P \\parallel Q) \\neq D_{KL}(Q \\parallel P) DKL​(P∥Q)=DKL​(Q∥P) 在离散情况下将KL散度公式展开可得： KL(P∣∣Q)=∑P(x)logP(x)Q(x)=−∑P(x)log(Q(x))+∑P(x)log(P(x))=H(P,Q)−H(P) KL(P∣∣Q)=∑P(x)log \\frac{P(x)}{Q(x)}=−∑P(x)log(Q(x))+∑P(x)log(P(x))=H(P,Q)−H(P) KL(P∣∣Q)=∑P(x)logQ(x)P(x)​=−∑P(x)log(Q(x))+∑P(x)log(P(x))=H(P,Q)−H(P) 最后得到的第一项 H(P,Q)H(P, Q)H(P,Q) 称作 PPP 和 QQQ 的交叉熵（cross entropy），后面一项 H(P)H(P)H(P) 就是熵。\n在信息论中，熵代表着信息量，H(P)H(P)H(P) 代表着基于 PPP 分布自身的编码长度，也就是最优的编码长度（最小字节数）。\n而 H(P,Q)H(P,Q)H(P,Q) 则代表着用 QQQ 的分布去近似 PPP 分布的信息，自然需要更多的编码长度。并且两个分布差异越大，需要的编码长度越大。所以两个值相减是大于等于0的一个值，代表冗余的编码长度，也就是两个分布差异的程度。所以KL散度在信息论中还可以称为相对熵（relative entropy）。\n对深度学习中的生成模型来说，我们希望最小化真实数据分布与生成数据分布之间的KL散度，从而使得生成数据尽可能接近真实数据的分布。在实际场景中，我们是几乎不可能知道真实数据分布 Pdata(x)P_{data}(x)Pdata​(x) 的，我们使用训练数据形成的生成分布在逼近 Pdata(x)P_{data}(x)Pdata​(x) 。\n2.3 KL散度和交叉熵 # 由上推导可得： DKL(P∥Q)=H(P,Q)−H(P) D_{KL}(P \\parallel Q) = H(P, Q) - H(P) DKL​(P∥Q)=H(P,Q)−H(P)其中：\n熵（Entropy）：H(P)=−∑xP(x)log⁡P(x)H(P) = - \\sum_x P(x) \\log P(x)H(P)=−∑x​P(x)logP(x) 交叉熵（Cross-Entropy）：H(P,Q)=−∑xP(x)log⁡Q(x)H(P,Q) = - \\sum_x P(x) \\log Q(x)H(P,Q)=−∑x​P(x)logQ(x) KL散度：DKL(P∥Q)D_{KL}(P \\parallel Q)DKL​(P∥Q) 交叉熵损失函数 是深度学习中常见的优化目标。\n标签分布 PPP：真实标签（通常是 one-hot） 模型分布 QQQ：预测的 softmax 概率 它与 KL 散度 的关系为： L=H(P,Q)=−∑xP(x)log⁡Q(x)=DKL(P∥Q)+H(P) L=H(P,Q) =- \\sum_x P(x) \\log Q(x)= D_{KL}(P \\parallel Q) + H(P) L=H(P,Q)=−x∑​P(x)logQ(x)=DKL​(P∥Q)+H(P) 因为 H(P)H(P)H(P) 与模型无关，所以： 最小化交叉熵 ↔ 最小化 KL 散度。 3. 线性回归、逻辑回归、Softmax回归 # 3.1 线性回归（Linear Regression） # 目标： 用于预测连续型数值。\n模型形式：\ny=wTx+b y = \\mathbf{w}^T \\mathbf{x} + b y=wTx+b 输出： 直接给出一个实数值，适合房价预测、销量预测等回归问题。\n损失函数： 通常采用均方误差 (MSE)：\nL=1n∑i=1n(yi−y^i)2 L = \\frac{1}{n}\\sum_{i=1}^n (y_i - \\hat{y}_i)^2 L=n1​i=1∑n​(yi​−y^​i​)2 3.2 逻辑回归（Logistic Regression） # 目标： 用于二分类问题。\n核心思想： 在回归结果外加一个 Sigmoid 函数，把连续实数映射到区间 [0,1][0,1][0,1]，表示概率。\n模型形式：\nP(y=1∣x)=σ(wTx+b)=11+e−(wTx+b) P(y=1|\\mathbf{x}) = \\sigma(\\mathbf{w}^T \\mathbf{x} + b) = \\frac{1}{1 + e^{-(\\mathbf{w}^T \\mathbf{x}+b)}} P(y=1∣x)=σ(wTx+b)=1+e−(wTx+b)1​ 输出： 概率值，可根据阈值（通常为 0.5）进行分类。\n损失函数： 对数似然损失（交叉熵）：\nL=−1n∑i=1n[yilog⁡(y^i)+(1−yi)log⁡(1−y^i)] L = -\\frac{1}{n}\\sum_{i=1}^n \\Big[ y_i \\log(\\hat{y}_i) + (1-y_i)\\log(1-\\hat{y}_i)\\Big] L=−n1​i=1∑n​[yi​log(y^​i​)+(1−yi​)log(1−y^​i​)] 3.3 Softmax回归（Softmax Regression） # 目标： 逻辑回归的多分类扩展。\n核心思想： 对每一类计算一个线性函数，然后通过 Softmax 函数 把它们转化为概率分布。\n模型形式：\n对于 KKK 类分类：\nP(y=k∣x)=ewkTx+bk∑j=1KewjTx+bj P(y=k|\\mathbf{x}) = \\frac{e^{\\mathbf{w}_k^T \\mathbf{x} + b_k}}{\\sum_{j=1}^K e^{\\mathbf{w}_j^T \\mathbf{x} + b_j}} P(y=k∣x)=∑j=1K​ewjT​x+bj​ewkT​x+bk​​ 输出： 每一类的概率，取概率最大的类作为预测类别。\n损失函数： 多分类交叉熵损失：\nL=−1n∑i=1n∑k=1Kyi,klog⁡y^i,k L = -\\frac{1}{n}\\sum_{i=1}^n \\sum_{k=1}^K y_{i,k} \\log \\hat{y}_{i,k} L=−n1​i=1∑n​k=1∑K​yi,k​logy^​i,k​ 其中 yi,ky_{i,k}yi,k​ 是 one-hot 标签。\n3.4 区别与联系 # 模型基础相同： 三者都是基于 线性模型（权重和输入特征的线性组合），区别主要在于输出层和损失函数。 输出空间不同： 线性回归输出的是一个实数（无范围限制）。 逻辑回归输出的是 [0,1][0,1][0,1] 概率（二分类）。 Softmax 回归输出的是多维概率分布（多分类）。 应用场景不同： 线性回归：预测连续数值（回归问题）。 逻辑回归：预测是否属于某一类（二分类问题）。 Softmax 回归：预测属于哪一类（多分类问题）。 演变关系： 逻辑回归可以看作在线性回归的基础上加 Sigmoid 的“分类版”。 Softmax 回归是逻辑回归的多分类推广。 4. 混淆矩阵 # ​\t在二分类问题中，我们有两个类别：正类（Positive）和负类（Negative）。根据模型对每个样本的预测值和label值之间的关系，每个预测结果可以落在四个格子中的一个。然后对每个格子中的样本计数。\n真实正类（Positive） 真实负类（Negative） 预测正类 True Positive(TP) False Positive(FP) 预测负类 False Negative(FN) True Negative(TN) **TP：**模型正确地将正样本预测为正类的数量。\n**FP：**模型错误地将负样本预测为正类的数量（误报）。\n**FN：**模型错误地将正样本预测为负类的数量（漏报）。\n**TN：**模型正确地将负样本预测为负类的数量。\n有了混淆矩阵，对于准确率的定义就可以表示为： Accuracy=TP+TNTP+TN+FP+FN Accuracy= \\frac{TP+TN}{TP+TN+FP+FN} Accuracy=TP+TN+FP+FNTP+TN​ 4.1 精确率 # **精确率：**被预测为正例的样本中，真正正例的比例。它的公式为： Preciosn=TPTP+FP Preciosn=\\frac{TP}{TP+FP} Preciosn=TP+FPTP​ 4.2 召回率 # **召回率：**真实为正类的样本中，被正确预测为正类的比例。 Recal=TPTP+FN Recal=\\frac{TP}{TP+FN} Recal=TP+FNTP​ 4.3 F1-Score # F1-Score：\n​\t精确率衡量的是预测为正例的情况下真正为正例的百分比。选择精确率作为项目的目标，可以保证你得到的正例，实际为正例的比例很高。在逻辑回归中，你可以提高sigmoid输出大于0.5，比如0.8以上作为正例来提高精确率。但是这样做可能会把一些本来是正例的样本误判为负例。\n​\t召回率衡量的是真正的正例中你发现的比率。还是以对罕见病预测的例子，10万个人中，只有1个是正例，你预测10万个人都是正例，召回率是1，但是这个模型也没有什么用。\n​\t好的模型应该同时能兼顾精确率和召回率，F1-Score就是一个同时兼顾两者的指标，它的公式为： F1=2∗Precison∗RecallPrecision+Recall F1=2*\\frac{Precison*Recall}{Precision+Recall} F1=2∗Precision+RecallPrecison∗Recall​ 5. Batch Norm, Layer Norm, Instance Norm \u0026amp; Group Norm # 跨通道与跨样本的关键区别：\n跨通道：在同一个样本内，跨所有通道计算均值和方差（如 Layer Normalization 和 Group Normalization中的一部分）。举例：RGB 图像中，三个颜色通道的统计量被一起计算。 跨样本：在同一通道内，跨所有样本（比如在一个 batch内）计算均值和方差（如 Batch Normalization）。举例：在一个批次中所有图片的 R 通道上计算统计量。 5.1 Batch Norm # 针对单个通道，在批次维度 (Batch) 上计算均值和方差。（跨样本单通道）\n公式计算如下： x^=x−μbatchσbatch2+ϵ \\hat{x}=\\frac{x-\\mu_{\\text{batch}}}{\\sqrt{\\sigma^2_{\\text{batch}}+\\epsilon}} x^=σbatch2​+ϵ​x−μbatch​​ 其中：\nμbatch,σbatch2\\mu_{\\text{batch}}, \\sigma^2_{\\text{batch}}μbatch​,σbatch2​：某一通道在整个 Batch 的均值和方差。 ϵ：小常数，防止除零。 通俗理解： 想象输入了 8 张图片（Batch Size = 8，宽高分别为4*4），每张图片的 RGB 通道单独计算均值和方差。Batch Norm 就是对每个通道的所有图片（总共 8×4×4=128 个值）统一归一化，让数据在训练中更稳定。\n优缺点：\n优点：适合大批量训练，能有效提高训练收敛速度和模型精度。 缺点：对小批量数据（例如，内存限制导致每次只训练少量样本）表现较差，统计量估计不稳定 5.2 Layer Norm # 在每个样本的，所有通道 (Channel) 上归一化。（单样本跨通道）\n公式计算如下： x^=x−μlayerσlayer2+ϵ \\hat{x}=\\frac{x-\\mu_{\\text{layer}}}{\\sqrt{\\sigma^2_{\\text{layer}}+\\epsilon}} x^=σlayer2​+ϵ​x−μlayer​​ 其中：\nμbatch,σbatch2\\mu_{\\text{batch}}, \\sigma^2_{\\text{batch}}μbatch​,σbatch2​：某一样本在整个 Batch 的均值和方差。 通俗理解： 假设我们只看一张图片（例如大小 3×4×4）。Layer Norm 会把这张图片所有通道的数据（3 个通道的 4×4=48 个值）统一归一化。适合处理单张图片的场景。\n优缺点：\n优点：独立于批次大小，适合 RNN 等顺序模型。 缺点：对视觉任务（如图像分类）性能不如 BN，因为假设所有通道的分布相似，这在卷积层中可能不成立。 5.3 Instance Norm # 在每个样本的每个通道上归一化，只使用空间维度。（单样本单通道）\n计算公式如下： x^=x−μinstanceσinstance2+ϵ \\hat{x}=\\frac{x-\\mu_{\\text{instance}}}{\\sqrt{\\sigma^2_{\\text{instance}}+\\epsilon}} x^=σinstance2​+ϵ​x−μinstance​​ 其中：\nμinstance,σinstance2\\mu_{\\text{instance}}, \\sigma^2_{\\text{instance}}μinstance​,σinstance2​ ：单张图片的某个通道内的均值和方差。 通俗理解： Instance Norm 针对单张图片的每个通道进行独立归一化。例如对红色通道的 4×4=16 个值归一化，绿色和蓝色通道分别独立归一化。适合图像生成任务，如风格迁移。\n优缺点：\n优点：常用于生成模型（如风格迁移），对每个样本进行独立处理。 缺点：忽略了通道间的依赖关系，表达能力较弱。 5.4 Group Norm # 将通道划分为若干组 (Groups)，在每组内计算归一化。（单样本跨几个通道，LN 的简化版本） x^=x−μgroupσgroup2+ϵ \\hat{x}=\\frac{x-\\mu_{\\text{group}}}{\\sqrt{\\sigma^2_{\\text{group}}+\\epsilon}} x^=σgroup2​+ϵ​x−μgroup​​ 其中：\nμgroup,σgroup2\\mu_{\\text{group}}, \\sigma^2_{\\text{group}}μgroup​,σgroup2​ ：某一组通道的均值和方差。 通俗理解： 假设我们有 32 个通道，Group Norm 将其分成 4 组，每组 8 个通道。然后对每组的所有值进行归一化。例如：红、绿、蓝可能被分配到不同的组内，每组内单独计算均值和方差。适合小 Batch 的训练。\n优缺点：\n优点：摆脱了对批次大小的依赖，在各种批次大小下都能表现稳定。 缺点：计算复杂度略高于 BN。 6. 字符编码 # ​\t我们知道，计算机内部，所有信息最终都是一个二进制值。每一个二进制位（bit）有0和1两种状态，因此八个二进制位就可以组合出256种状态，这被称为一个字节（byte）。也就是说，一个字节一共可以用来表示256种不同的状态，每一个状态对应一个符号，就是256个符号，从00000000到11111111。\n6.1 ASCII编码 # ​\tASCII(American Standard Code for Information Interchange，美国标准信息交换代码)是基于拉丁字母的一套电脑编码系统，主要用于显示现代英语和其他西欧语言。它是现今最通用的单字节编码系统，并等同于国际标准ISO/IEC 646。\n​\tASCII 码使用指定的 7 位或 8 位二进制数组合来表示128 或256 种可能的字符。标准ASCII 码也叫基础ASCII码，使用 7 位二进制数（第一位二进制为0）来表示所有的大写和小写字母，数字0 到9、标点符号， 以及在美式英语中使用的特殊控制字符。如下图所示：\n其中：\n​\t0～31及127(共33个)是控制字符或通信专用字符（其余为可显示字符），如控制符：LF（换行）、CR（回车）、FF（换页）、DEL（删除）、BS（退格)、BEL（响铃）等；通信专用字符：SOH（文头）、EOT（文尾）、ACK（确认）等；ASCII值为8、9、10 和13 分别转换为退格、制表、换行和回车字符。它们并没有特定的图形显示，但会依不同的应用程序，而对文本显示有不同的影响。\n​\t32～126(共95个)是字符(32是空格），其中48～57为0到9十个阿拉伯数字。\n​\t65～90为26个大写英文字母，97～122号为26个小写英文字母，其余为一些标点符号、运算符号等。\n​\t后128个称为扩展ASCII码。许多基于x86的系统都支持使用扩展（或“高”）ASCII。扩展ASCII 码允许将每个字符的第8 位用于确定附加的128 个特殊符号字符、外来语字母和图形符号。如下图所示： 6.2 GBK编码 # ​\t由于ASCII编码不支持中文，因此，当中国人用到计算机时，就需要寻求一种编码方式来支持中文。 于是，国人就定义了一套编码规则：当字符小于127位时，与ASCII的字符相同，但当两个大于127的字符连接在一起时，就代表一个汉字，第一个字节称为高字节（从0xA1-0xF7）,第二个字节为低字节（从0xA1-0xFE）,这样大约可以组合7000多个简体汉字。这个规则叫做GB2312。 ​ 但是由于中国汉字很多，有些字无法表示，于是重新定义了规则：不在要求低字节一定是127之后的编码，只要第一个字节是大于127，就固定表示这是一个汉字的开始，不管后面跟的是不是扩展字符集里的内容。这种扩展之后的编码方案称之为GBK标准，包括了GB2312的所有内容，同时新增了近20000个新的汉字（包括繁体字）和符号。 ​ 但是，中国有56个民族，所以，我们再次对编码规则进行了扩展，又加了近几千个少数民族的字符，于是再次扩展后得编码叫做GB18030。中国的程序员觉得这一系列编码的标准是非常的好，于是统统称他们叫做\u0026quot;DBCS\u0026quot;（Double Byte Charecter Set 双字节字符集）。\n6.3 Unicode字符集 # ​\t因为世界国家很多，每个国家都定义一套自己的编码标准，结果相互之间谁也不懂谁的编码，就无法进行很好的沟通交流，所以及时的出现了一个组织ISO（国际标准化组织）决定定义一套编码方案来解决所有国家的编码问题，这个新的编码方案就叫做Unicode。注意Unicode不是一个新的编码规则，而是一套字符集（为每一个「字符」分配一个唯一的 ID（学名为码位 / 码点 / Code Point）），可以将Unicode理解为一本世界编码的字典。 ​\tISO规定：每个字符必须使用俩个字节，即用16位二进制来表示所有的字符，对于ASCII编码表里的字符，保持其编码不变，只是将长度扩展到了16位，其他国家的字符全部统一重新编码。由于传输ASCII表里的字符时，实际上可以只用一个字节就可以表示，所以，这种编码方案在传输数据比较浪费带宽，存储数据比较浪费硬盘。\n​\tUnicode的问题如下：\n​\t需要注意的是，Unicode 只是一个符号集，它只规定了符号的二进制代码，却没有规定这个二进制代码应该如何存储。\n​\t比如，汉字严的 Unicode 是十六进制数4E25，转换成二进制数足足有15位（100111000100101），也就是说，这个符号的表示至少需要2个字节。表示其他更大的符号，可能需要3个字节或者4个字节，甚至更多。\n​\t这里就有两个严重的问题，第一个问题是，如何才能区别 Unicode 和 ASCII ？计算机怎么知道三个字节表示一个符号，而不是分别表示三个符号呢？第二个问题是，我们已经知道，英文字母只用一个字节表示就够了，如果 Unicode 统一规定，每个符号用三个或四个字节表示，那么每个英文字母前都必然有二到三个字节是0，这对于存储来说是极大的浪费，文本文件的大小会因此大出二三倍，这是无法接受的。\n​\t它们造成的结果是：\n​\t1）出现了 Unicode 的多种存储方式，也就是说有许多种不同的二进制格式，可以用来表示 Unicode。\n​\t2）Unicode 在很长一段时间内无法推广，直到互联网的出现。\n6.4 UTF-8编码 # ​\t由于Unicode比较浪费网络带宽和硬盘，因此为了解决这个问题，就在Unicode的基础上，定义了一套编码规则（将「码位」转换为字节序列的规则（编码/解码 可以理解为 加密/解密 的过程）），这个新的编码规则就是UTF-8，采用1-4个字符进行传输和存储数据。\n​\t编码规则：使用下面的模板进行转换：\n​\t如果一个字节以“0”开头，说明该字符采用单字节编码方式，这与传统的ASCII编码保持兼容。若一个字节以“10”开头，则表明它不是字符的起始字节，而是某个多字节字符的中间部分，需要向前寻找起始位置。若以“110”开头，则表示该字符使用两个字节进行编码；以“1110”开头则代表采用三个字节；而若以“11110”开头，则说明该字符由四个字节构成。\nUnicode符号范围（十六进制）\t| UTF-8编码方式(二进制) ------------------------------------------------------------------------ 0000 0000-0000 007F | 0xxxxxxx\t(ASCII码) 0000 0080-0000 07FF | 110xxxxx 10xxxxxx 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx 6.5 Little endian 和 Big endian # ​\tUCS-2 格式可以存储 Unicode 码（码点不超过0xFFFF）。以汉字严为例，Unicode 码是4E25，需要用两个字节存储，一个字节是4E，另一个字节是25。存储的时候，4E在前，25在后，这就是 Big endian 方式；25在前，4E在后，这是 Little endian 方式。\n​\t这两个古怪的名称来自英国作家斯威夫特的《格列佛游记》。在该书中，小人国里爆发了内战，战争起因是人们争论，吃鸡蛋时究竟是从大头(Big-endian)敲开还是从小头(Little-endian)敲开。为了这件事情，前后爆发了六次战争，一个皇帝送了命，另一个皇帝丢了王位。\n​\t第一个字节在前，就是\u0026quot;大头方式\u0026quot;（Big endian），第二个字节在前就是\u0026quot;小头方式\u0026quot;（Little endian）。\n​\t那么很自然的，就会出现一个问题：计算机怎么知道某一个文件到底采用哪一种方式编码？\n​\tUnicode 规范定义，每一个文件的最前面分别加入一个表示编码顺序的字符，这个字符的名字叫做\u0026quot;零宽度非换行空格\u0026quot;（zero width no-break space），用FE FF表示。这正好是两个字节，而且FF比FE大1。\n​\t如果一个文本文件的头两个字节是FE FF，就表示该文件采用大头方式；如果头两个字节是FF FE，就表示该文件采用小头方式。\n​\t计算机中notebook的编码方式：\n​\t有四个选项：ANSI，Unicode，Unicode big endian和UTF-8\n1）ANSI是默认的编码方式：对于英文文件是ASCII编码，对于简体中文文件是GB2312编码（只针对 Windows 简体中文版，如果是繁体中文版会采用 Big5 码）； 2）Unicode编码这里指的是notepad.exe使用的 UCS-2 编码方式：即直接用两个字节存入字符的 Unicode 码，这个选项用的 little endian 格式； 3）Unicode big endian编码与上一个选项相对应：我在下一节会解释 little endian 和 big endian 的涵义； 4）UTF-8编码：也就是上一节谈到的编码方法。\n6.6 字符编码的转换 # ​\tUTF-8与Unicode的转换：\n​\tUTF-8区分每个字符的开始是根据字符的高位字节来区分的，比如用一个字节表示的字符，第一个字节高位以“0”开头；用两个字节表示的字符，第一个字节的高位为以“110”开头，后面一个字节以“10开头”；用三个字节表示的字符，第一个字节以“1110”开头，后面俩字节以“10”开头；用四个字节表示的字符，第一个字节以“11110”开头，后面的三个字节以“10”开头。\n​\t比如汉字“智”，utf-8编码是“\\xe6\\x99\\xba”对应的二进制为：“111001101001100110111010”，由于utf-8中一个汉字是3个字节，所以对应的模板为“0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx”。\n11100110 10011001 10111010 1110xxxx 10xxxxxx 10xxxxxx 0110 011001 111010 ​\t0110011001111010代表十六进制667A，因此根据规则转换得出“智”Unicode的位置为为“667A”。同样，根据Unicode中字符的编码位置，也能找到对应的utf-8编码。\n​\tUnicode与GBK的转换：\n​\t比如汉字“路”，在gbk中的编码为“\\xc2\\xb7”,对应的二进制为：“1100 0010 1011 0111”。同时“路”在Unicode字符集中的位置是“\\u8def”(python中的Unicode类型)，因此可以通过“\\u8def”在Unicode字符集中找到“路”对应的编码为“4237”，对应的二进制为：“0100 0010 0011 0111”，由于gbk的俩个字节的高字节是为了区分中文和ASCII，所以将“1100 0010 1011 0111”高字节的“1”去掉后，就对应Unicode字符集中的0100 0010 0011 0111”。\n​\tUTF-8和Unicode与GBK的关系\n​\tutf-8\u0026mdash;\u0026mdash;\u0026ndash;decode(解码)\u0026mdash;\u0026ndash;\u0026gt;Unicode类型\u0026lt;\u0026mdash;\u0026mdash;-decode(解码)\u0026mdash;\u0026ndash;gbk\n​\tutf-8\u0026lt;\u0026mdash;\u0026mdash;\u0026ndash;encode(编码)\u0026mdash;\u0026ndash;Unicode类型\u0026mdash;\u0026mdash;-encode(编码)\u0026mdash;\u0026ndash;\u0026gt;gb\n7. 激活函数 # 7.1 Sigmoid、ReLU、Tanh、GELU、Swish # 7.1.1 Sigmoid(S型函数) # 公式： f(x)=11+e−x f(x) = \\frac{1}{1 + e^{-x}} f(x)=1+e−x1​ 图像如下所示：\n图像特征：\n输出范围在 (0, 1) 之间 在 x=0x=0x=0 处取值为 0.5 当 x→+∞x\\to +\\inftyx→+∞ 时趋近于 1 当 x→−∞x \\to -\\inftyx→−∞ 时趋近于 0 优缺点：\n优点：可平滑地将输入映射到有限区间 缺点：存在梯度消失问题，不适合深层网络 7.1.2 ReLU(线性整流单元) # 公式： f(x)=max⁡(0,x) f(x) = \\max(0, x) f(x)=max(0,x) 图像如下所示： 图像特征：\n当 x\u0026lt;0x \u0026lt; 0x\u0026lt;0 时，输出为 0 当 x≥0x \\ge 0x≥0 时，输出为 x 输出范围：[0,+∞)[0, +\\infty)[0,+∞) 优缺点：\n优点：计算简单，梯度传播效率高，不易饱和 缺点：“神经元死亡”问题（当权重更新使输出恒小于0时，梯度为0） 7.1.3 Tanh(双曲正切函数) # 公式： f(x)=tanh(x)=ex−e−xex+e−x f(x)=tanh(x)=\\frac{e^{x}-e^{-x}}{e^{x}+e^{-x}} f(x)=tanh(x)=ex+e−xex−e−x​ 图像如下：\n图像特征：\n输出范围在 (-1, 1) 之间 关于原点对称（奇函数） 当 x→+∞x \\to +\\inftyx→+∞ 时，输出趋近于 1 当 x→−∞x \\to -\\inftyx→−∞ 时，输出趋近于 -1 在中间区域（接近 0）近似线性增长 优缺点： **优点：**输出均值接近 0，有助于加快梯度下降收敛；相比 Sigmoid，饱和区间更小，梯度消失问题稍轻 **缺点：**当输入绝对值较大时，仍会出现饱和导致梯度消失；计算复杂度略高于 ReLU 7.1.4 GELU(高斯误差线性单元) # 公式： f(x)=x⋅Φ(x)=x⋅12[1+erf(x2)] f(x) = x \\cdot \\Phi(x) = x \\cdot \\frac{1}{2}\\left[1 + \\text{erf}\\left(\\frac{x}{\\sqrt{2}}\\right)\\right] f(x)=x⋅Φ(x)=x⋅21​[1+erf(2​x​)] 其中 Φ(x)\\Phi(x)Φ(x)为标准正态分布的累积分布函数，erf\\text{erf}erf 为误差函数。\n近似形式（更常用的计算形式）： f(x)≈0.5x(1+tanh⁡[2π(x+0.044715x3)]) f(x) \\approx 0.5x \\left(1 + \\tanh\\left[\\sqrt{\\frac{2}{\\pi}}(x + 0.044715x^3)\\right]\\right) f(x)≈0.5x(1+tanh[π2​​(x+0.044715x3)]) 图像如下所示：\n图像特征：\n介于 ReLU 与 Sigmoid 之间的平滑曲线 当 x\u0026lt;0x \u0026lt; 0x\u0026lt;0 时有小的负输出 平滑且连续 优缺点：\n优点：相比 ReLU 更平滑，适合 Transformer、BERT 等模型 缺点：计算复杂度较高 7.1.5 Swish(自门控激活函数) # 公式： f(x)=x⋅σ(x)=x1+e−x f(x) = x \\cdot \\sigma(x) = \\frac{x}{1 + e^{-x}} f(x)=x⋅σ(x)=1+e−xx​ 图像如下所示： 图像特征：\n介于 ReLU 与 Sigmoid 之间 当 x\u0026lt;0x \u0026lt; 0x\u0026lt;0 时输出略为负值 平滑且单调递增 优缺点：\n优点：连续可导、性能优于 ReLU 缺点：计算略复杂于 ReLU 7.2 GLU() # ​\tGLU 的全称是 Gated Linear Unit（门控线性单元），它不是一个单纯的激化函数，而是一种 “门控机制（Gating Mechanism）”，通常出现在 序列建模、语言模型（如 Transformer、ConvSeq2Seq）中。\n​\tGLU 全称为 Gated Linear Unit，即门控线性单元函数。其计算过程如下：\n​\t首先输入向量 xxx 经过两个独立的卷积层/MLP层，得到向量 AAA 和向量 BBB 。此时，向量AAA 和向量 BBB 的公式为： A=x⋅W+b，B=x⋅V+cA=x⋅W+b，B=x⋅V+cA=x⋅W+b，B=x⋅V+c。然后向量 BBB 经过一个 sigmoid 函数之后，σ(B)\\sigma(B)σ(B) 中的每个元素就都变为了0～1之间的值，就可以起到控制信息是否通过的作用。\n​\t将向量 AAA 与 σ(B)\\sigma(B)σ(B) 逐个元素相乘之后就得到了GLU层的最终输出结果，把上面描述的整个过程结合起来，GLU的公式如下： GLU(x)=A⊗σ(B)=(x⋅W+b)⊗σ(x⋅V+c)(1) \\text{GLU}(x)=A\\otimes\\sigma(B)=(x \\cdot W+b) \\otimes \\sigma(x \\cdot V+c)\\tag{1} GLU(x)=A⊗σ(B)=(x⋅W+b)⊗σ(x⋅V+c)(1) ​\t在该公式中，xxx 表示输入向量，⊗\\otimes⊗ 表示两个向量逐元素相乘，σ\\sigmaσ 表示sigmoid函数。\n​\t当GLU作为激活函数时一般不是上述公式的形式。激活函数 ReLU 的公式为 ReLU=max⁡(0,x)\\text{ReLU}=\\max(0,x)ReLU=max(0,x)，其含义就是当输入向量 xxx 的值小于 0 时直接阻断，当时输入向量 xxx 的值大于 0 值直接通过。参考ReLU激活函数，激活函数GLU的公式为如下公式的形式： GLU(x)=x⊗σ(g(x))(2) \\text{GLU}(x)=x \\otimes \\sigma(g(x))\\tag{2} GLU(x)=x⊗σ(g(x))(2) ​\t这里有一个新符号 g(x)g(x)g(x) 表示的是向量 xxx 经过一层MLP或者卷积层，其他部分的符号与 公式(1) 中是相同的。可以看出 公式(2) 的右半部分与 公式(1) 的右半部分是完全一样的，所不同的是 公式(2) 中的左半部分直接就是向量 xxx，这和上面描述的激活函数ReLU是相似的，当 σ(g(x))\\sigma(g(x))σ(g(x)) 趋近于 0 时表示对 xxx 进行阻断，当 σ(g(x))\\sigma(g(x))σ(g(x)) 趋近于 1 时表示允许 xxx 通过，以此实现门控激活函数的效果。\n优缺点：\n优点：\n能动态控制信息流，增强模型的表示能力 在语言建模和序列任务中表现优异（如 CNN-based Transformer、GPT 变体） 可以抑制噪声、选择性激活有用特征 缺点：\n引入了额外的参数和计算量（需要两个通道） 不属于“逐元素非线性函数”，结构更复杂 7.3 GLU的变种 # GLU 的问题：\nSigmoid 输出区间过窄 → 容易饱和、梯度较小； 输出仅依赖线性组合 → 非线性能力有限； 在深层模型（如 Transformer FFN 层）中，ReLU / GELU 仍更高效。 因此研究者提出了一系列改进的“门控激活函数”：\nReGLU、GEGLU、SwiGLU 这些变体的主要思想是： 保留门控结构，但用更强的激活函数（ReLU、GELU、Swish）替代 Sigmoid。\n7.3.1 ReGLU # ReGLU(A,B)=A⊗ReLU(B) \\text{ReGLU}(A,B)=A \\otimes \\text{ReLU}(B) ReGLU(A,B)=A⊗ReLU(B) 由 Google 的 PaLM 论文中提出。 替换了 Sigmoid → ReLU，使门控具有稀疏性。 优点：计算简单，梯度不易消失。 缺点：不平滑，ReLU 在 0 处不可导。 7.3.2 GEGLU # GEGLU(A,B)=A⊗GELU(B) \\text{GEGLU}(A,B)=A \\otimes \\text{GELU}(B) GEGLU(A,B)=A⊗GELU(B) 来自《PaLM: Scaling Language Models with Pathways》（Google, 2022）。 使用 GELU 代替 Sigmoid。 GELU 是平滑非线性函数，输出接近正态分布，更自然地控制信息流。 实验显示**：GEGLU 在 Transformer 的 FFN（前馈层）中效果优于 ReLU/GLU**。 7.3.3 SwiGLU # SwiGLU(A,B)=A⊗Swish(B) \\text{SwiGLU}(A,B)=A \\otimes \\text{Swish}(B) SwiGLU(A,B)=A⊗Swish(B) 来自《T5 v1.1》和后续 LLaMA 系列模型。 使用 Swish 代替 Sigmoid 门控。 Swish 比 GELU 更连续、非单调，梯度更平滑； 性能和稳定性都非常好，因此 LLaMA2 / 3、GPT-NeoX、T5.1.1、Qwen 等模型都采用它。 7.3.4 总结对比 # 引入更强的非线性：门控层的激活函数由 Sigmoid → GELU/Swish，使模型更具表达能力。 梯度传播更平滑：GELU / Swish 的平滑特性避免了 ReLU 的梯度断层问题。 计算效率提升：仍然只需一次线性层输出 2 倍维度（如 2dff2d_{ff}2dff​），再分为 A,BA,BA,B，总体代价与传统 FFN 相近。 实际验证有效 PaLM 论文：GEGLU 比 ReLU 提升约 0.3–0.5 BLEU / perplexity。 LLaMA 论文：SwiGLU 提升稳定性与收敛速度。 名称 公式 门控函数 特点与改进 GLU A⋅σ(B)A \\cdot \\sigma(B)A⋅σ(B) Sigmoid 最初形式，容易饱和 ReGLU A⋅ReLU(B)A \\cdot \\text{ReLU}(B)A⋅ReLU(B) ReLU 门控更稀疏，计算简单 GEGLU A⋅GELU(B)A \\cdot \\text{GELU}(B)A⋅GELU(B) GELU 更平滑，表现优异，已用于 PaLM SwiGLU A⋅Swish(B)A \\cdot \\text{Swish}(B)A⋅Swish(B) Swish 平滑且自适应，已用于 LLaMA、T5 v1.1、GPT-NeoX 等 ​\tGLU 变体（ReGLU、GEGLU、SwiGLU）是通过替换门控函数，使门控机制更平滑、更高效，从而增强 Transformer 的非线性表达能力和训练稳定性。 ​\t其中 SwiGLU 是目前主流大语言模型（如 LLaMA、Qwen）默认采用的版本。\n7.4 FFN公式及其变体 # ​\t我们知道FFN就是输入向量 xxx 先经过一个MLP层，将维度上升到原来的4倍，然后过一个激活函数层，最后再经过一个MLP层，将维度还原回去。其公式为： FFN(x,W1,W2,b1,b2)=ReLU(xW1+b1)W2+b2 \\text{FFN}(x,W_1,W_2,b_1,b_2)=\\text{ReLU}(xW_1+b_1)W_2+b_2 FFN(x,W1​,W2​,b1​,b2​)=ReLU(xW1​+b1​)W2​+b2​ ​\t有些研究，比如T5，将实验中的偏置项去掉了，这样FFN就变为了如下形式： FFNReLU(x,W1,W2)=ReLU(xW1)W2 \\text{FFN}_\\text{ReLU}(x,W_1,W_2)=\\text{ReLU}(xW_1)W_2 FFNReLU​(x,W1​,W2​)=ReLU(xW1​)W2​ ​\t把这里的激活函数由 ReLU 替换为 GeLU 或 Swish 之后，就得到两种新的FFN，公式如下，下面这两种也都有相应的实验结果。 FFNGeLU(x,W1,W2)=GeLU(xW1)W2 \\text{FFN}_\\text{GeLU}(x,W_1,W_2)=\\text{GeLU}(xW_1)W_2 FFNGeLU​(x,W1​,W2​)=GeLU(xW1​)W2​FFNSwish(x,W1,W2)=Swish(xW1)W2 \\text{FFN}_\\text{Swish}(x,W_1,W_2)=\\text{Swish}(xW_1)W_2 FFNSwish​(x,W1​,W2​)=Swish(xW1​)W2​​\t使用GLU 及其各种变体替换掉 FFN 中的第一个MLP和激活函数，就可以得到如下新版的 FFN 公式： FFNGLU(x,W,V)=[σ(xW)⊗(xV)]W2 \\text{FFN}_\\text{GLU}(x,W,V)=[\\sigma(xW) \\otimes (xV)]W_2 FFNGLU​(x,W,V)=[σ(xW)⊗(xV)]W2​FFNBilinear(x,W,V)=[(xW)⊗(xV)]W2 \\text{FFN}_\\text{Bilinear}(x,W,V)=[(xW) \\otimes (xV)]W_2 FFNBilinear​(x,W,V)=[(xW)⊗(xV)]W2​FFNReGLU(x,W,V)=[ReLU(xW)⊗(xV)]W2 \\text{FFN}_\\text{ReGLU}(x,W,V)=[\\text{ReLU}(xW) \\otimes (xV)]W_2 FFNReGLU​(x,W,V)=[ReLU(xW)⊗(xV)]W2​FFNGEGLU(x,W,V)=[GELU(xW)⊗(xV)]W2 \\text{FFN}_\\text{GEGLU}(x,W,V)=[\\text{GELU}(xW) \\otimes (xV)]W_2 FFNGEGLU​(x,W,V)=[GELU(xW)⊗(xV)]W2​FFNSwiGLU(x,W,V)=[Swish(xW)⊗(xV)]W2 \\text{FFN}_\\text{SwiGLU}(x,W,V)=[\\text{Swish}(xW) \\otimes (xV)]W_2 FFNSwiGLU​(x,W,V)=[Swish(xW)⊗(xV)]W2​ 8. 优化器 # 8.1 SGD # **概念：SGD 是最基础的优化算法。它每次使用一个样本（或小批量样本）**来估计梯度，从而更新参数。\n更新公式： θt+1=θt−η∇θL(θt) \\theta_{t+1} = \\theta_t - \\eta \\nabla_\\theta L(\\theta_t) θt+1​=θt​−η∇θ​L(θt​) 其中：\nθt\\theta_tθt​：模型参数 η\\etaη：学习率 ∇θL(θt)\\nabla_\\theta L(\\theta_t)∇θ​L(θt​)：损失函数关于参数的梯度 特点：\n简单易实现； 但梯度震荡较大，容易在谷底附近来回振荡； 对学习率非常敏感。 8.2 Momentum # **概念：**Momentum 在 SGD 基础上引入“惯性”思想，模拟小球在曲面上滚动的过程。即在当前梯度方向上考虑之前累积的更新趋势，使优化过程更平滑。\n更新公式： vt=βvt−1+(1−β)∇θL(θt) v_t = \\beta v_{t-1} + (1 - \\beta)\\nabla_\\theta L(\\theta_t) vt​=βvt−1​+(1−β)∇θ​L(θt​)θt+1=θt−ηvt \\theta_{t+1} = \\theta_t - \\eta v_t θt+1​=θt​−ηvt​其中：\nvtv_tvt​：动量项（梯度的指数加权平均） 有时在计算完 vtv_tvt​ 还会进行修正（v0=0v_0=0v0​=0），修正公式：vtcorrect=vt1−βtv_t^{correct}=\\frac{v_t}{1-\\beta^t}vtcorrect​=1−βtvt​​ β\\betaβ：动量系数，常取 0.9 特点：\n减少震荡，加快收敛； 在鞍点处更容易逃脱； 若动量过大可能导致越过最优点。 8.3 RMSProp # 概念：RMSProp 是一种自适应学习率算法。它对每个参数的梯度平方进行指数加权平均，从而自动调整学习率大小。\n更新公式： st=βst−1+(1−β)(∇θL(θt))2 s_t = \\beta s_{t-1} + (1 - \\beta)(\\nabla_\\theta L(\\theta_t))^2 st​=βst−1​+(1−β)(∇θ​L(θt​))2θt+1=θt−ηst+ϵ∇θL(θt) \\theta_{t+1} = \\theta_t - \\frac{\\eta}{\\sqrt{s_t + \\epsilon}} \\nabla_\\theta L(\\theta_t) θt+1​=θt​−st​+ϵ​η​∇θ​L(θt​)其中：\nsts_tst​：历史梯度平方的指数加权平均 ϵ\\epsilonϵ：防止除零的小常数（通常为 10−810^{-8}10−8） 特点：\n对梯度大（变化剧烈）的参数自动减小学习率； 对梯度小（变化缓慢）的参数自动增大学习率； 收敛速度快，尤其适合非平稳目标。 8.4 Adam # **概念：**Adam 结合了 Momentum + RMSProp 的优点：既保留动量的加速效果，又具有自适应学习率的能力。\n更新公式： mt=β1mt−1+(1−β1)∇θL(θt) m_t = \\beta_1 m_{t-1} + (1 - \\beta_1)\\nabla_\\theta L(\\theta_t) mt​=β1​mt−1​+(1−β1​)∇θ​L(θt​)vt=β2vt−1+(1−β2)(∇θL(θt))2 v_t = \\beta_2 v_{t-1} + (1 - \\beta_2)(\\nabla_\\theta L(\\theta_t))^2 vt​=β2​vt−1​+(1−β2​)(∇θ​L(θt​))2（偏差校正） m^t=mt1−β1t,v^t=vt1−β2t \\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}, \\quad \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} m^t​=1−β1t​mt​​,v^t​=1−β2t​vt​​ （参数更新） θt+1=θt−ηv^t+ϵm^t \\theta_{t+1} = \\theta_t - \\frac{\\eta}{\\sqrt{\\hat{v}_t} + \\epsilon}\\hat{m}_t θt+1​=θt​−v^t​​+ϵη​m^t​ 计算高效； 适合大规模、稀疏梯度的任务； 几乎不需要手动调整学习率； 但有时会出现收敛到次优解的问题（AdamW 改进了这一点）。 8.5 AdamW # 概念：AdamW 是对 Adam 的改进版本，用于更合理地进行权重衰减（正则化）。Adam 的原始实现中，权重衰减被隐式耦合在梯度更新中，这会影响优化方向；而 AdamW 将权重衰减与梯度更新显式分离。\n更新公式： thetat+1=θt−η(m^tv^t+ϵ+λθt) theta_{t+1} = \\theta_t - \\eta \\left( \\frac{\\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon} + \\lambda \\theta_t \\right) thetat+1​=θt​−η(v^t​​+ϵm^t​​+λθt​) 其中：\nλ\\lambdaλ：权重衰减系数（类似于 L2 正则） 特点：\n避免了 Adam 的正则化偏差问题； 改善了泛化性能； 是现代 Transformer（如 BERT、GPT 系列）默认使用的优化器。 9. 权重衰减、正则化与梯度裁剪 # 9.1 权重衰退 # 概念\n权重衰减是一种在优化过程中对权重参数进行衰减的技术，使模型的参数不会无限增大，防止过拟合。\n其核心思想是：\n在每次更新参数时，让参数值向零稍微靠近一点。\n数学原理\n在常规梯度下降中： θt+1=θt−η∇θL(θt) \\theta_{t+1} = \\theta_t - \\eta \\nabla_\\theta L(\\theta_t) θt+1​=θt​−η∇θ​L(θt​) 加入权重衰减后： θt+1=(1−ηλ)θt−η∇θL(θt) \\theta_{t+1} = (1 - \\eta \\lambda)\\theta_t - \\eta \\nabla_\\theta L(\\theta_t) θt+1​=(1−ηλ)θt​−η∇θ​L(θt​) 其中：\nλ\\lambdaλ：权重衰减系数（一般为 1e-4 ~ 1e-2） η\\etaη：学习率 等价于在损失函数中加入一个 L2 正则项： L′=L+λ2∥θ∥2L\u0026#x27; = L + \\frac{\\lambda}{2} \\|\\theta\\|^2L′=L+2λ​∥θ∥2\n实现方式\n在 PyTorch 中：\noptimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) 对于 SGD、Adam 等优化器，weight_decay 参数控制衰减强度。 注意： AdamW 明确将权重衰减与梯度更新分离（比 Adam 原始实现更正确）。 效果\n减少模型复杂度； 防止权重过大导致过拟合； 改善泛化性能； 常用于 Transformer、CNN、RNN 等模型训练中。 9.2 正则化 # 概念：正则化是一类用于防止模型过拟合的技术，通过对模型参数或输出施加约束，使其更平滑、更具泛化能力。\n常见类型：\n类型 数学表达式 效果 典型用途 L1 正则化 $L\u0026rsquo; = L + \\lambda \\sum_i \\theta_i $ L2 正则化 L′=L+λ2∑iθi2L\u0026#x27; = L + \\frac{\\lambda}{2}\\sum_i \\theta_i^2L′=L+2λ​∑i​θi2​ 抑制大权重，提升泛化 Ridge、神经网络 Dropout 随机丢弃一部分神经元 防止共适应，增强鲁棒性 全连接网络 Early Stopping 监控验证集损失，提前停止 防止过拟合 所有模型通用 BatchNorm/LayerNorm 归一化激活值 减小内部协变量偏移 CNN/Transformer L2 与权重衰减的关系\n在 SGD 中，权重衰减与 L2 正则是等价的； 在 Adam / RMSProp 中，只有 AdamW 的实现才与 L2 正则严格等价； 因此推荐用 AdamW(weight_decay=...) 而不是在损失中手动加正则项。 9.3 梯度裁剪 # 概念：梯度裁剪用于防止梯度爆炸（Gradient Explosion）——在反向传播时梯度过大导致参数更新剧烈，从而训练不稳定或发散。\n数学原理\n假设梯度向量为 ggg，设定最大范数(L2范数) clip_norm\\text{clip\\_norm}clip_norm。\n若：∥g∥2\u0026gt;clip_norm\\|g\\|_2 \u0026gt; \\text{clip\\_norm}∥g∥2​\u0026gt;clip_norm\n则缩放为： g′=clip_norm∥g∥2gg\u0026#x27; = \\frac{\\text{clip\\_norm}}{\\|g\\|_2} gg′=∥g∥2​clip_norm​g\n否则保持不变。这相当于“压缩”梯度向量长度。\n实现方式\n在 PyTorch 中：\ntorch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 或按数值裁剪（每个元素限制在 [-x, x]）：\ntorch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5) 应用场景\nRNN / LSTM / GRU 等循环结构（梯度容易爆炸）； Transformer 大模型训练； 一般与 AdamW 搭配使用，稳定训练过程。 10. Warm-up、学习率衰减策略、梯度检查点、梯度累加 # 10.1 Warm-up # **概念：**Warm-up 是在训练初期 逐步增加学习率 的策略，用于防止模型在初始阶段因为学习率过大而导致训练不稳定或梯度爆炸。\n**原理：**模型刚开始训练时，参数是随机初始化的，若立即使用较大学习率，会导致 loss 振荡或发散。通过 Warm-up，学习率会从一个较小的值 线性或指数上升 到目标学习率。\n常见公式（线性 warm-up）：lrt=lrbase×tTwarmuplr_t = lr_{base} \\times \\frac{t}{T_{warmup}}lrt​=lrbase​×Twarmup​t​\n其中 ttt 为当前步数，TwarmupT_{warmup}Twarmup​为预热步数。\n**典型用法：**通常与学习率衰减（cosine、step、linear）配合使用，如 Transformer、YOLO 等模型。\n10.2 学习率衰减策略 # **概念：**学习率衰减是指在训练过程中 逐渐减小学习率 的方法，以便模型在后期更稳定地收敛。\n常见策略：\n**Step Decay（阶梯衰减）：**每隔一定 epoch 学习率下降一个固定比例。\nlr=lr0×γ⌊epoch/step⌋lr = lr_0 \\times \\gamma^{\\lfloor epoch / step \\rfloor}lr=lr0​×γ⌊epoch/step⌋\n**Exponential Decay（指数衰减）：**学习率按指数规律下降。\nlr=lr0×e−k×epochlr = lr_0 \\times e^{-k \\times epoch}lr=lr0​×e−k×epoch\n**Cosine Annealing（余弦退火）：**学习率按余弦曲线平滑下降，可周期性重启。\nlrt=lrmin+12(lrmax−lrmin)(1+cos⁡(tTπ))lr_t = lr_{min} + \\frac{1}{2}(lr_{max} - lr_{min})(1 + \\cos(\\frac{t}{T}\\pi))lrt​=lrmin​+21​(lrmax​−lrmin​)(1+cos(Tt​π))\n**Linear Decay（线性衰减）：**从初始值线性下降到某个最小值，常用于 fine-tuning。\n**作用：**早期快速学习，后期稳定收敛，防止震荡。\n1. 步数的定义\n步（step）= 一次参数更新（optimizer.step()）\n也就是说，每执行一次反向传播并更新模型参数，计作 1 个训练步。\n2. 步数与 epoch、batch 的关系\n假设：\n数据集共有 N 条样本 批大小（batch size）为 B 梯度累加步数（gradient_accumulation_steps）为 A 训练轮数（epochs）为 E 那么训练的总步数计算公式为：total_steps=NB×A×E\\text{total\\_steps} = \\frac{N}{B \\times A} \\times Etotal_steps=B×AN​×E\n举例说明\n假设你的数据集有 50,000 条样本：\nbatch_size = 16 gradient_accumulation_steps = 4 epochs = 3 则：total_steps=5000016×4×3=2343.75≈2344\\text{total\\_steps} = \\frac{50000}{16 \\times 4} \\times 3 = 2343.75 \\approx 2344total_steps=16×450000​×3=2343.75≈2344\n即整个训练过程会执行约 2344 次参数更新。\n10.3 梯度检查点 # 概念： 梯度检查点是一种在训练深度神经网络时 节省显存（GPU Memory） 的方法。其核心思想是：用时间换空间 —— 在前向传播阶段不保存所有层的中间激活值，只在关键节点（称为“检查点”）保存少量激活，其余部分在反向传播时通过重新计算得到。\n**原理：**在普通训练中，反向传播需要用到每层的激活值，因此这些值都会被保存在显存中。而在梯度检查点机制下，只保存若干层的输出，其他层的输出在反向传播时重新计算，从而显著减少显存占用。 这种方式能让深层模型（如 Transformer、YOLOv10、LLaMA 等）在相同显存条件下支持更大的 batch 或更长的输入序列。\n示意：\n常规训练：层1 → 层2 → 层3 → 层4 （全部保存激活） 梯度检查点：仅保存层1、层4的激活，层2、层3在反向传播时重算 效果：\n显存占用可降低约 30%～60%； 计算时间增加约 20%～50%，因为反向传播时需要重新执行部分前向计算。 典型用途：\n大模型（LLM、Transformer、YOLO）训练或微调时节省显存； 低显存显卡（如单卡 3090）进行大 batch 训练或长序列任务； 常与 LoRA / QLoRA、混合精度训练（FP16/BF16） 一起使用以进一步优化显存。 PyTorch 实现：\nfrom torch.utils.checkpoint import checkpoint def forward_pass(x): return model_block(x) out = checkpoint(forward_pass, input) 或直接在模型中启用：\nmodel.gradient_checkpointing_enable() 10.4 梯度累加 # **概念：**当 GPU 显存不足以支撑大批量（batch size）训练时，可通过梯度累加来 模拟更大的 batch。\n**原理：**不是每次反向传播都更新参数，而是累积多次 mini-batch 的梯度，再统一更新。\n**公式示例：**若希望等效 batch size = 64，但 GPU 只能放下 16，可设置：accumulation_steps=64/16=4\\text{accumulation\\_steps} = 64 / 16 = 4accumulation_steps=64/16=4\n即每 4 次反向传播后再更新一次权重。\n示例代码：\nfor step, batch in enumerate(dataloader): loss = model(batch) loss = loss / accumulation_steps loss.backward() if (step + 1) % accumulation_steps == 0: optimizer.step() optimizer.zero_grad() 优点：\n在显存受限条件下稳定训练 保持梯度估计精度，改善收敛性 11. 大模型推理参数 # ​\t在大模型推理的最后一步，模型输出的是一个词表大小的向量（Logits）。推理参数主要用于干预这个 Logits 向量经过 Softmax 变成最终 Token 概率分布的过程，以及控制解码路径的搜索策略。\n11.1 Temperature (温度系数) # 公式： Pi=exp⁡(zi/T)∑jexp⁡(zj/T) P_i = \\frac{\\exp(z_i / T)}{\\sum_j \\exp(z_j / T)} Pi​=∑j​exp(zj​/T)exp(zi​/T)​ 其中： ziz_izi​ ：模型输出的原始 Logit 值 TTT ：温度系数 解释：\nT→0T \\to 0T→0 (Low Temperature)：\n分母上的 TTT 越小，Logit 值大的项经过指数运算后会被放大得越明显。 概率分布变得非常 尖锐 (Peaked)，模型趋向于选择概率最高的词（Greedy）。 适用场景：代码生成、数学解题（需要确定性）。 T→∞T \\to \\inftyT→∞ (High Temperature)：\n分布趋向于 均匀分布 (Uniform)。 概率分布变 平坦 (Flat)，低概率词被选中的机会增加。 适用场景：创意写作、头脑风暴（需要多样性）。 11.2 Top-k \u0026amp; Top-p (采样策略) # ​\t这两种策略用于截断概率分布的“长尾”，防止模型采样到极低概率的离谱词汇。\n11.2.1 Top-k (硬截断) # **逻辑：**只保留概率最高的 kkk 个 Token，将其他 Token 的概率清零，并重新归一化剩余的 Token。\n优缺点：\n缺点：kkk 是固定的。若分布很尖，大 kkk 会引入噪音；若分布很平，小 kkk 会切掉合理选项。 11.2.2 Top-p (Nucleus Sampling / 核采样) # 逻辑： 按照概率从高到低排序，选取累计概率之和达到阈值 ppp（如 0.9）的那一组最小的 Token 集合。\n特点：\n动态调整：当模型确信时，候选集小；当模型犹豫时，候选集自动变大。 建议：通常建议 Top-p 和 Temperature 联用（先 Top-k/Top-p 截断，后 Temperature 缩放）。 11.3 Num_beams (集束搜索) # 用于控制解码时的搜索路径数量，从 Sampling (采样) 转向 Search (搜索)。\nGreedy Search (num_beams=1)：\n每一步只选概率最高的一个词。 缺陷：容易陷入局部最优，无法回溯。 Beam Search (num_beams \u0026gt; 1)：\n在每一步生成时，时刻保留 num_beams 条当前总概率最高的路径。 特点：寻找全局最优解，生成的句子更通顺，但计算代价大。 适用场景：机器翻译、文本摘要（需要高质量、确定性的结果）。 11.4 总结对比 # 参数 核心作用 数学/逻辑本质 建议设置 (以 Qwen3 为例) Temperature 创造力 vs. 准确性 Logits 的 Scaling Factor SFT 模型通常 0.7~0.9 Top-k 排除绝对错误的词 固定数量截断 设为 50 左右兜底 Top-p 动态调整候选池大小 累积概率截断 (CDF) 推荐 0.9 或 0.95 Num_beams 搜索质量 vs. 速度 广度优先搜索变体 聊天/创作设为 1；翻译设为 3-5 参考：\n彻底搞懂：Batch Norm, Layer Norm, Instance Norm \u0026amp; Group Norm\nBatch Normalization（批归一化）和 Layer Normalization（层归一化）的一些细节可能和你想的并不一样\n字符编码那点事：快速理解ASCII、Unicode、GBK和UTF-8\nASCII、Unicode、GBK、UTF-8之间的关系_gbk和ascii\nGLU和SwiGLU\n十分钟搞明白Adam和AdamW，SGD，Momentum，RMSProp，Adam，AdamW\n用梯度检查点来节省显存 gradient checkpointing\n大模型训练如何计算显存占用\n10分钟搞明白如何设置大模型推理参数，top_k，top_p, temperature, num_beams。温度，beam search。\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/llm-basic/general/","section":"文章","summary":"","title":"LLM 通用基础","type":"posts"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0/","section":"标签","summary":"","title":"激活函数","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0/","section":"标签","summary":"","title":"损失函数","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E4%BD%8D%E7%BD%AE%E7%BC%96%E7%A0%81/","section":"标签","summary":"","title":"位置编码","type":"tags"},{"content":"","date":"2026年6月4日","externalUrl":null,"permalink":"/tags/%E4%BC%98%E5%8C%96%E5%99%A8/","section":"标签","summary":"","title":"优化器","type":"tags"},{"content":"hello hugo\nNote 普通说明、补充信息。\nTip 建议、技巧、推荐做法。\nImportant 重要信息，需要特别注意。\nWarning 警告，可能有风险或副作用。\nCaution 谨慎，风险更高或需要避免误操作。\nSuccess 成功操作。\nError 错误操作。\n","date":"2026年6月4日","externalUrl":null,"permalink":"/posts/my-first-post/","section":"文章","summary":"","title":"My First Post","type":"posts"},{"content":"","externalUrl":null,"permalink":"/authors/","section":"Authors","summary":"","title":"Authors","type":"authors"},{"content":"","externalUrl":null,"permalink":"/series/","section":"Series","summary":"","title":"Series","type":"series"}]