从离散到连续:扩散模型与自回归模型的融合生成范式深度解析
一、背景介绍
在生成式AI的演进历程中,两类主流范式长期占据着主导地位:自回归模型与扩散模型。前者以GPT、DALL-E为代表,通过逐步预测离散token实现生成;后者则以Stable Diffusion、Imagen为代表,通过连续空间中的逐步去噪获得高质量图像。长期以来,这两条技术路线各自发展,鲜有交集。
然而,随着2023年DiT(Diffusion Transformer)和2024年MAR(Masked Autoregressive)系列工作的出现,一个令人振奋的趋势逐渐清晰:将扩散过程的连续去噪与自回归的离散预测相结合,正在成为文生图领域的新主流方向。这种融合并非简单的技术堆叠,而是在概率建模层面实现了深刻的统一。
传统自回归模型面临的核心挑战在于:离散token的预测天然缺乏对全局一致性的建模能力,导致长距离依赖难以捕捉。而扩散模型虽然在图像质量上表现出色,但其连续去噪过程缺乏显式的结构约束,难以实现灵活的局部控制。融合范式正是为了取长补短——用自回归的因果结构提供生成框架,用扩散的连续去噪保证视觉质量。
从应用角度看,这种融合范式在多个维度展现出显著优势:生成质量达到甚至超越纯扩散模型,推理速度较纯自回归模型提升数倍,同时支持条件控制、局部编辑等高级功能。在视频生成领域,这种范式更是展现出独特价值——利用自回归的时间结构结合扩散的空间建模,能够生成既连贯又高质的视频内容。
二、技术原理
2.1 核心思想:离散骨架与连续纹理
融合范式的核心洞察在于:视觉生成可以分解为两个阶段——离散的“骨架”预测和连续的“纹理”填充。自回归模型擅长捕捉离散token之间的内在结构关系,这恰好对应于图像的语义骨架;扩散模型擅长从噪声中恢复连续细节,这对应于图像的纹理质感。
具体而言,融合模型通常采用两阶段架构:
- 离散编码阶段:使用VQ-VAE或类似方法将图像编码为离散token序列
- 混合生成阶段:自回归模型预测token序列,扩散模型在token对应的连续空间中进行去噪
这种设计巧妙地将两种范式的优势结合:自回归部分提供因果约束和灵活的条件控制,扩散部分确保每个token对应的视觉区域具有高质量的局部细节。
2.2 数学基础:从交叉熵到扩散损失
理解融合范式的关键在于统一两种损失函数。自回归模型使用交叉熵损失:
L_ar = -Σ log p(x_i | x_{<i})扩散模型使用噪声预测损失:
L_diff = E[||ε - ε_θ(x_t, t)||²]在融合范式中,这两种损失被巧妙地结合。以MAR(Masked Autoregressive)为例,其核心创新在于引入“掩码自回归”机制:
- 随机掩码部分token
- 使用自回归方式预测掩码token
- 对预测结果应用扩散损失进行细化
数学上,这等价于构建一个混合概率模型:
p(x) = Σ_m p(m) · p_ar(x_m | x_{¬m}) · p_diff(x_{¬m} | x_m)其中m为掩码模式,p_ar为自回归预测分布,p_diff为条件扩散分布。
2.3 关键创新:连续token表示
传统自回归模型将每个token映射为离散类别,而融合范式引入连续token表示。每个token对应一个连续向量,扩散过程在这个连续空间中执行去噪。这种设计带来了几个关键优势:
- 信息密度提升:连续表示可以编码更丰富的视觉信息
- 梯度传播友好:避免离散化导致的梯度截断
- 自然支持插值:连续空间中的线性插值对应视觉上的平滑过渡
具体实现上,通常采用“量化-反量化”策略:编码器将图像映射为连续向量,经过向量量化得到离散索引,解码器将离散索引映射回连续空间。扩散模型作用于解码器输出的连续表示上。
三、系统架构设计
3.1 整体架构
[
系统采用分层架构设计,从上到下依次为:
- 控制层:接收文本提示、图像条件等输入
- 生成层:包含自回归模块和扩散模块
- 表示层:负责图像与token之间的转换
- 优化层:提供推理加速和内存管理
3.2 模块详细设计
VQ-VAE编码器:
- 输入:RGB图像 (H x W x 3)
- 输出:离散token序列 (h x w)
- 压缩比:通常为16x或8x
自回归Transformer:
- 架构:Causal Transformer Decoder
- 输入:部分可见的token序列
- 输出:预测的下一个token分布
扩散去噪器:
- 架构:U-Net或DiT
- 输入:噪声化连续表示 + 时间步
- 输出:预测噪声
条件融合模块:
- 将文本嵌入与视觉特征交叉注意力
- 支持多种条件形式(文本、图像、掩码)
3.3 数据流设计
生成过程的数据流分为三个阶段:
阶段一:骨架生成
文本 → 文本编码器 → 自回归Transformer → 离散token序列阶段二:连续映射
离散token → 嵌入表 → 连续向量序列阶段三:细节细化
连续向量 + 噪声 → 扩散去噪器 → 精细连续表示 → VQ-VAE解码器 → 图像四、核心实现(Golang代码)
4.1 基础数据结构
// token表示图像中的离散标记typeTokenstruct{Indexint// 离散索引Embed[]float32// 对应的连续嵌入向量Maskedbool// 是否为掩码状态}// 图像表示,包含离散和连续两种形式typeImageRepresentationstruct{DiscreteTokens[]Token// 离散token序列ContinuousLatent[]float32// 连续潜在表示Height,Widthint// 空间维度}// 扩散配置参数typeDiffusionConfigstruct{Timestepsint// 总去噪步数BetaStartfloat32// 噪声调度起始BetaEndfloat32// 噪声调度结束ScheduleTypestring// 调度类型:linear/cosine}// 自回归配置参数typeARConfigstruct{MaxSeqLenint// 最大序列长度NumLayersint// Transformer层数NumHeadsint// 注意力头数EmbedDimint// 嵌入维度VocabSizeint// 词汇表大小}4.2 VQ-VAE编码器实现
// VQVAEEncoder 将图像编码为离散tokentypeVQVAEEncoderstruct{ConvLayers[]ConvLayer Codebook[]float32// 码本向量EmbedDimint}func(e*VQVAEEncoder)Encode(image[]float32)(*ImageRepresentation,error){// 1. 卷积下采样latent:=imagefor_,layer:=rangee.ConvLayers{latent=layer.Forward(latent)}// 2. 向量量化:找到最近的码本向量h,w:=len(latent)/e.EmbedDim,e.EmbedDim tokens:=make([]Token,h*w)fori:=0;i<h*w;i++{// 计算当前向量与所有码本的距离minDist:=float32(math.MaxFloat32)bestIdx:=0forj,code:=rangee.Codebook{dist:=euclideanDistance(latent[i*e.EmbedDim:(i+1)*e.EmbedDim],code)ifdist<minDist{minDist=dist bestIdx=j}}// 记录离散索引和连续嵌入tokens[i]=Token{Index:bestIdx,Embed:e.Codebook[bestIdx*e.EmbedDim:(bestIdx+1)*e.EmbedDim],Masked:false,}}return&ImageRepresentation{DiscreteTokens:tokens,ContinuousLatent:latent,Height:h,Width:w,},nil}4.3 自回归生成器
// ARGenerator 自回归token预测器typeARGeneratorstruct{Transformer*CausalTransformer Config*ARConfig}// Generate 自回归生成token序列