本文详细介绍了如何使用 PyTorch 从零开始编写 Stable Diffusion LoRA 训练脚本,包括数据预处理、模型修改、训练循环、参数保存与加载等关键步骤。特别强调了 LoRA 层的手动实现和在 UNet 的 Cross-Attention 层注入 LoRA 的原因,以及在其他层应用 LoRA 的可能性和注意事项。此外,还提供了代码示例和参数效率的讨论,帮助读者深入理解 LoRA 在 Stable Diffusion 微调中的应用。
Q: 训练一个 Stable Diffusion 的 LoRA 模型,要有哪些步骤?
训练 Stable Diffusion 的 LoRA(Low-Rank Adaptation)模型是一种高效的微调方法,可以在保留原模型大部分参数的同时,通过低秩矩阵调整模型行为。以下是详细的步骤指南:
1. 数据准备
- 收集数据集:
- 需要与目标风格/主题相关的高质量图片(建议至少 100-200 张,数量越多效果越好)。
- 图片格式通常为
.jpg
或.png
,分辨率建议与 Stable Diffusion 的基础模型对齐(如 512x512)。
- 标注数据:
- 每张图片需搭配文本描述(
.txt
文件),描述需清晰简洁(如"a photo of a [keyword], [style], [details]"
)。 - 若训练特定对象(如人物、画风),确保描述中统一使用触发词(如
"my_cat_v1"
)。
- 每张图片需搭配文本描述(
2. 环境配置
-
安装依赖库:
# 基础环境 pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu117 pip install diffusers transformers accelerate peft pip install datasets huggingface_hub # 可选:xformers 优化(A卡需调整) pip install xformers
-
硬件要求:
- GPU 显存至少 8GB(如 RTX 3060/4090),24GB 可支持更大批次。
- 使用
accelerate
库配置分布式训练(可选)。
3. 模型准备
-
选择基础模型:
- 常用模型:
runwayml/stable-diffusion-v1-5
、stabilityai/stable-diffusion-2-1
。 - 从 Hugging Face 下载并缓存模型:
from diffusers import StableDiffusionPipeline pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
- 常用模型:
-
配置 LoRA 参数:
- Rank(秩): 控制低秩矩阵的维度(典型值:4-64,越大拟合能力越强,但可能过拟合)。
- 在训练脚本中设置 LoRA 层(通常作用于 Cross-Attention 模块):
from peft import LoraConfig lora_config = LoraConfig( r=8, # Rank 值 target_modules=["to_k", "to_q", "to_v", "to_out.0"], # 注入的模块 init_lora_weights="gaussian", )
4. 训练配置
-
超参数设置:
batch_size: 1-4(根据显存调整) learning_rate: 1e-5 到 1e-4(常用 1e-4) num_epochs: 50-200(观察 loss 调整) optimizer: AdamW(带权重衰减) mixed_precision: fp16(节省显存) gradient_checkpointing: True(进一步省显存)
-
训练脚本:
- 使用
diffusers
的train_text_to_image.py
或自定义训练循环。 - 关键代码片段:
from diffusers import DDPMScheduler scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") trainer = StableDiffusionTrainer( model=pipeline.unet, args=training_args, data_collator=collate_fn, train_dataset=dataset, scheduler=scheduler, ) trainer.add_callback(LoRACallback(lora_config)) # 注入 LoRA
- 使用
5. 训练过程
-
监控与调试:
- 观察 Loss 曲线(理想情况应稳步下降后趋于平稳)。
- 定期生成测试图片(如每 100 步)检查过拟合/欠拟合。
- 调整策略:
- Loss 不降:提高学习率、增加 Rank 值。
- Loss 震荡:降低学习率、增大批次大小。
-
保存检查点:
- 每隔一定步数保存中间模型(如
./output/checkpoint-500
)。 - 最终模型保存为
.safetensors
或合并到原模型。
- 每隔一定步数保存中间模型(如
6. 推理与部署
-
加载 LoRA 权重:
from diffusers import StableDiffusionPipeline pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipeline.unet.load_attn_procs("path/to/lora_weights.safetensors")
-
生成图片:
prompt = "A photo of my_cat_v1 sitting on a sofa, best quality" image = pipeline(prompt, num_inference_steps=50).images[0] image.save("output.png")
常见问题
-
Q: 需要多少数据?
- A: 简单风格/物体约 100 张,复杂概念需 500+ 张。
-
Q: 训练时间多久?
- A: 单 GPU 约 1-6 小时(取决于数据量和参数)。
-
Q: 能否用 Colab 免费版?
- A: 可以,但需启用
fp16
和梯度检查点,批次设为 1。
- A: 可以,但需启用
通过以上步骤,你可以高效地训练出针对特定任务的 LoRA 模型,显著降低资源消耗同时保持生成质量。
你确定训练一个 LoRA 需要 100 多张图吗?
Reasoning 过程省略
训练 LoRA 模型所需的图片数量并非固定,具体取决于训练目标、数据质量和参数设置。根据多个来源的建议,最低可接受数量为 10-20 张高质量图片,而复杂场景或高精度需求可能需要 50 张以上。 ...
Q: 我有 Python 与 PyTorch 基础。需要自己写一个 LoRA 训练脚本,应如何写?
以下是基于 PyTorch 实现 Stable Diffusion LoRA 训练脚本的完整指南,从底层原理到代码实现逐步解析。假设你已准备好数据集(图片+文本描述),我们将以 手动实现 LoRA 层 为核心,不依赖 peft
库。
1. LoRA 层的原理与实现
LoRA 的核心思想:在原始权重矩阵旁增加一个低秩矩阵,只训练这部分参数。
import torch
import torch.nn as nn
class LoRALayer(nn.Module):
def __init__(self, original_layer, rank=8, alpha=16):
super().__init__()
self.original = original_layer # 原始线性层 (e.g., nn.Linear)
self.original.requires_grad_(False) # 冻结原始参数
in_dim = original_layer.in_features
out_dim = original_layer.out_features
self.rank = rank
self.alpha = alpha
# 初始化低秩矩阵 A 和 B
self.A = nn.Parameter(torch.randn(in_dim, rank))
self.B = nn.Parameter(torch.zeros(rank, out_dim))
def forward(self, x):
# 原始输出 + 低秩增量
orig_out = self.original(x)
lora_out = x @ self.A @ self.B # [batch, out_dim]
return orig_out + self.alpha * lora_out
2. 修改 Stable Diffusion 的 UNet
需将 LoRA 注入到 UNet 的 Cross-Attention 层(关键步骤):
from diffusers import UNet2DConditionModel
def inject_lora(unet, rank=8, alpha=16):
# 遍历 UNet 的所有注意力层
for name, module in unet.named_modules():
if "attn2" in name and "to_" in name: # 定位 Cross-Attention 的 Q/K/V/Out 层
if isinstance(module, nn.Linear):
# 替换原始 Linear 层为 LoRA 层
lora_layer = LoRALayer(module, rank=rank, alpha=alpha)
# 通过字符串路径替换模块 (需要解析层级结构)
parent = unet
parts = name.split('.')
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], lora_layer)
return unet
# 加载原始 UNet 并注入 LoRA
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet = inject_lora(unet, rank=8)
3. 数据预处理
将图片和文本转换为模型输入:
from torchvision import transforms
from transformers import CLIPTokenizer, CLIPTextModel
# 加载 Stable Diffusion 的文本编码器
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# 图像预处理(转换为潜在空间)
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
def preprocess_image(image_path):
image = Image.open(image_path).convert("RGB")
transform = transforms.Compose([
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
])
image = transform(image).unsqueeze(0) # [1, 3, 512, 512]
with torch.no_grad():
latents = vae.encode(image).latent_dist.sample() * 0.18215 # [1, 4, 64, 64]
return latents
# 文本编码
def encode_text(text):
inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=77, truncation=True)
with torch.no_grad():
text_embeddings = text_encoder(inputs.input_ids)[0] # [1, 77, 768]
return text_embeddings
4. 训练循环
import torch.optim as optim
from diffusers import DDPMScheduler
# 初始化噪声调度器
noise_scheduler = DDPMScheduler(
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear"
)
# 优化器(仅训练 LoRA 参数)
optimizer = optim.AdamW(
filter(lambda p: p.requires_grad, unet.parameters()), # 只选择可训练参数
lr=1e-4,
weight_decay=1e-4
)
# 训练循环
for epoch in range(num_epochs):
for batch in dataloader:
# 加载数据
images, texts = batch
latents = preprocess_image(images)
text_embeds = encode_text(texts)
# 添加噪声
noise = torch.randn_like(latents)
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (latents.shape[0],))
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# 前向传播
noise_pred = unet(noisy_latents, timesteps, text_embeds).sample
# 计算损失
loss = nn.functional.mse_loss(noise_pred, noise)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Loss: {loss.item()}")
5. 保存 LoRA 权重
def save_lora_weights(unet, save_path):
lora_state_dict = {}
for name, module in unet.named_modules():
if isinstance(module, LoRALayer):
# 保存 A 和 B 矩阵
lora_state_dict[f"{name}.A"] = module.A
lora_state_dict[f"{name}.B"] = module.B
torch.save(lora_state_dict, save_path)
save_lora_weights(unet, "lora_weights.pth")
6. 推理时加载 LoRA
# 加载原始模型
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# 注入 LoRA
inject_lora(pipeline.unet, rank=8)
pipeline.unet.load_state_dict(torch.load("lora_weights.pth"), strict=False) # 仅加载 LoRA 参数
# 生成图片
image = pipeline("A photo of [your trigger word]").images[0]
关键注意事项
- 参数冻结:确保原始模型的参数被冻结,只训练 LoRA 的
A
和B
矩阵。 - Rank 选择:通常
rank=8
足够,复杂任务可尝试rank=16-64
。 - 混合精度训练:添加
with torch.autocast("cuda")
加速训练并减少显存。 - 梯度检查点:对
unet.enable_gradient_checkpointing()
以减少显存占用。 - 学习率调整:LoRA 学习率通常比全模型微调大(例如
1e-4
vs1e-5
)。
通过此脚本,你可以完全控制 LoRA 的实现细节,适用于定制化需求或研究场景。若追求快速开发,仍推荐使用 peft
库的封装。
Q: 你说“需将 LoRA 注入到 UNet 的 Cross-Attention 层”, LoRA 只能加到 CrossAttention 层吗?为什么?
你的问题非常关键!LoRA 并不只能加到 Cross-Attention 层,但在 Stable Diffusion 微调中,Cross-Attention 层是最关键的目标。以下是深度解析:
1. 理论层面:LoRA 可以加在哪里?
从数学原理上,LoRA 是一种通用的参数高效微调方法,可应用于任何线性层(如 nn.Linear
或等效的矩阵运算)。例如:
- Self-Attention 中的 Q/K/V 层(处理图像自身特征)
- Feed-Forward Network (FFN) 的中间层
- 甚至 VAE 或 Text Encoder 的某些层(但较少见)
但实际应用中,选择注入位置需权衡效果与效率。
2. 为什么 Stable Diffusion 优先选择 Cross-Attention 层?
(1) 文本-图像对齐的核心枢纽
- Cross-Attention 层是 UNet 中文本条件(Text Embedding)与图像特征交互的唯一通道。
- 修改这些层能直接控制生成内容如何响应文本提示(Prompt),例如:
- 学习新概念(如特定物体
my_cat_v1
) - 调整风格(如油画风
van gogh style
)
- 学习新概念(如特定物体
(2) 参数效率最大化
- Stable Diffusion 的 UNet 参数量巨大(约 860M),但 Cross-Attention 层仅占约 10% 的参数量。
- 仅微调这部分,即可用 0.1%~1% 的参数量达到接近全参数微调的效果,性价比极高。
(3) 实践经验验证
- 大量实验表明,调整 Cross-Attention 层对生成结果影响最显著。
- 若同时微调 Self-Attention 或 FFN 层,可能引入噪声(如破坏预训练模型的构图能力)。
3. 其他可能的目标层
(1) Self-Attention 层
- 作用:控制图像内部结构(如物体形状、布局)。
- 适用场景:需要修改生成图像的构图规则(如强制对称性)。
- 风险:容易导致模型遗忘预训练知识(如失去多样性)。
(2) FFN 层
- 作用:非线性特征变换。
- 适用场景:调整颜色、纹理等低级特征。
- 缺点:参数量大,效率较低。
(3) 组合策略
- 同时微调 Cross-Attention + 部分 Self-Attention,可平衡控制力与稳定性,但需更多显存和训练数据。
4. 代码示例:注入到 Self-Attention
只需修改 inject_lora
函数中的目标模块筛选条件:
def inject_lora(unet, rank=8):
for name, module in unet.named_modules():
# 定位 Self-Attention 的 Q/K/V 层("attn1" 表示 Self-Attention)
if "attn1" in name and "to_" in name:
if isinstance(module, nn.Linear):
# 替换为 LoRALayer
...
return unet
5. 如何选择注入位置?
目标层类型 | 参数量占比 | 影响范围 | 推荐场景 |
---|---|---|---|
Cross-Attention | ~10% | 文本控制、概念学习 | 学习新物体/风格(主流选择) |
Self-Attention | ~20% | 图像结构、布局 | 特殊构图需求 |
FFN | ~70% | 颜色、纹理 | 低级特征微调 |
6. 终极建议
- 优先 Cross-Attention:95% 的场景只需调整这里。
- 谨慎添加其他层:仅在数据充足且任务特殊时尝试(如医学图像生成需精确形状控制)。
- 监控过拟合:添加更多可训练参数会增加过拟合风险,需配合数据增强/早停法。
UNet2DConditionModel
的整个流程中, Cross-Attention 主要作用在下采样、中间块、上采样三个步骤中。在 __init__
参数中传入每个步骤中用的 Module ,其中用 CrossAttnDownBlock2D
等 Module self.has_cross_attention
为 true ,能处理 cross_attention_mask 等参数。其中的 attentions 层即为交叉注意力层。推测,是通过 Attentions 层对下采样层入参做处理,并列的 ResNet 应该是为了提升注意力层的训练性能。
UNet 代码中的 Cross-Attension 层
以下为在 Cursor 中有 diffusers 代码上下文时与 claude-3.7-sonnet 的对话内容。
UNet 中 Cross-Attension 是做什么的?具体到代码里如何作用? @unet_2d_condition.py
Cross-Attention 在 UNet 中主要用于将条件信息(如文本特征)引入到图像生成过程中。从代码来看,主要体现在以下几个方面:
-
数据流向:
- UNet 接收图像特征(
sample
)和条件编码(encoder_hidden_states
) encoder_hidden_states
通常是文本编码器输出的特征- 条件信息通过 Cross-Attention 在 UNet 的不同层级与图像特征交互
- UNet 接收图像特征(
-
代码实现:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, **additional_residuals, )
-
核心组件:
- 在下采样、中间块和上采样阶段都可能使用 Cross-Attention
- 通过
has_cross_attention
属性判断块是否需要使用 Cross-Attention - 当使用 Cross-Attention 时,
encoder_hidden_states
会作为额外输入传递
-
参数控制:
only_cross_attention
参数控制是否只使用 Cross-Attention 而不使用 Self-Attentioncross_attention_dim
指定了条件编码的维度大小transformer_layers_per_block
控制每个块中 Transformer 层的数量
Cross-Attention 允许模型在保持图像空间结构的同时,根据文本或其他条件调整生成内容,实现了条件生成的核心机制。
TransformerBlock 具体是在@transformer_2d.py 里实现吗?如何体现 Self-Attention 与 Cross-Attention ?
TransformerBlock 的核心实现是在attention.py
中的BasicTransformerBlock
类中,而不是直接在transformer_2d.py
中。在transformer_2d.py
中只是调用了这个类来构建 Transformer 模型。
从transformer_2d.py
代码中可以看到 Self-Attention 和 Cross-Attention 的几个关键特点:
-
实例化 transformer 块时:
BasicTransformerBlock( self.inner_dim, self.config.num_attention_heads, self.config.attention_head_dim, dropout=self.config.dropout, cross_attention_dim=self.config.cross_attention_dim, activation_fn=self.config.activation_fn, num_embeds_ada_norm=self.config.num_embeds_ada_norm, attention_bias=self.config.attention_bias, only_cross_attention=self.config.only_cross_attention, double_self_attention=self.config.double_self_attention, ... )
-
调用时的数据流:
hidden_states = block( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, )
核心区别:
- Self-Attention: 使用
hidden_states
自身进行注意力计算,attention_mask
作为掩码 - Cross-Attention: 使用
hidden_states
作为 query,encoder_hidden_states
作为 key 和 value,encoder_attention_mask
作为掩码
配置选项:
only_cross_attention
: 如果为 True,只使用 Cross-Attention 而跳过 Self-Attentiondouble_self_attention
: 如果为 True,使用两次 Self-Attentioncross_attention_dim
: 指定 Cross-Attention 的维度大小
在实际运行时,BasicTransformerBlock 内部会先进行 Self-Attention 操作,然后进行 Cross-Attention 操作(如果提供了encoder_hidden_states
)。这样就实现了 UNet 中文本条件对图像生成的影响。
@attention.py 中 BasicTransformerBlock 如何体现 Self-Attention 与 Cross-Attention ?
让我看看 attention.py 中的 BasicTransformerBlock 类,特别是它如何实现 Self-Attention 和 Cross-Attention 的部分。
从代码中,我可以看到 BasicTransformerBlock 类包含了以下几个关键组件:
初始化函数中,它设置了多个 Attention 模块和标准化层:
self.norm1 和 self.attn1 用于 Self-Attention
self.norm2 和 self.attn2 用于 Cross-Attention(如果需要的话)
self.norm3 和 self.ff 用于前馈网络
在 forward 函数中,处理流程如下:
首先对输入进行 normalization
然后应用 Self-Attention(self.attn1)
如果有 cross_attention_dim 或 double_self_attention,则应用 Cross-Attention(self.attn2)
最后是 Feed-Forward 网络
特别需要注意的是,对于 Self-Attention 和 Cross-Attention 的区别:
Self-Attention:使用 self.attn1,只使用 hidden_states 作为输入,在 only_cross_attention 为 True 时可能使用 encoder_hidden_states。
Cross-Attention:使用 self.attn2,hidden_states 作为 query,encoder_hidden_states 作为 key 和 value。
在 forward 方法中,我们可以清楚看到这一点:
对于 Self-Attention,attn_output = self.attn1(norm_hidden_states, ...)
对于 Cross-Attention,attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states, ...)
我会重点关注 forward 方法中实现 Self-Attention 和 Cross-Attention 的关键部分。
在attention.py
中,BasicTransformerBlock
类通过以下方式实现了 Self-Attention 和 Cross-Attention:
1. 架构设计
# 1. Self-Attn
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
2. 前向传播过程
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
# ... (处理输出)
hidden_states = attn_output + hidden_states
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
3. 关键区别
-
Self-Attention (attn1):
- 默认情况下只使用
hidden_states
作为输入(query, key, value 都来自同一源) - 只有当
only_cross_attention=True
时才会用encoder_hidden_states
- 使用
attention_mask
作为掩码
- 默认情况下只使用
-
Cross-Attention (attn2):
norm_hidden_states
作为 queryencoder_hidden_states
作为 key 和 value- 使用
encoder_attention_mask
作为掩码 - 当
encoder_hidden_states=None
时会退化为 Self-Attention
4. 配置控制
only_cross_attention
:是否在 attn1 中使用 cross-attentiondouble_self_attention
:是否将 attn2 配置为另一个 self-attention 层cross_attention_dim
:指定条件编码的维度
这种设计允许 UNet 模型在不同层级灵活组合 Self-Attention 和 Cross-Attention,实现对图像特征的内部处理和对条件信息的整合。
我注意到 @attention.py 中在 Attention 层前都会有一个 LayerNorm 层。这是必须要有的吗?有什么作用?
是的,在attention.py
的BasicTransformerBlock
中,每个 Attention 层前都有一个归一化层(通常是 LayerNorm),这是非常重要的设计:
LayerNorm 的必要性
-
训练稳定性:
# 在Self-Attention前应用归一化 if self.norm_type == "ada_norm": norm_hidden_states = self.norm1(hidden_states, timestep) elif self.norm_type == "ada_norm_zero": norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype ) elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: norm_hidden_states = self.norm1(hidden_states)
-
梯度流动:
- 归一化层将数据分布标准化,防止梯度消失或爆炸
- 特别是在深层网络中,保持梯度流动对训练至关重要
LayerNorm 的作用
-
特征分布稳定:
- 将特征归一化到相同尺度,减少批次间差异
- 使网络对输入变化更加鲁棒,提高泛化能力
-
残差连接的配合:
# 应用残差连接 hidden_states = attn_output + hidden_states
- Pre-LayerNorm 结构(先归一化后 Attention)与残差连接配合使用
- 这种"归一化 → 操作 → 残差"的模式是 Transformer 架构的基本组成部分
-
多种归一化选择:
# 不同类型的归一化层 if self.norm_type == "ada_norm": norm_hidden_states = self.norm1(hidden_states, timestep) elif self.norm_type == "ada_norm_zero": # ... elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: # ...
- 基本的 LayerNorm
- 带条件的 AdaLayerNorm(使用 timestep 信息调整)
- 其他特殊归一化变体,适应不同场景
总之,LayerNorm 在 Attention 层前是必不可少的,它是保证 Transformer 结构有效训练和良好性能的关键组件,特别是在生成模型这样需要处理大量复杂特征的深度网络中。