最近一直在做扩散模型相关工作,一方面疯狂吸收理论知识,扩散模型中的数学属实太多,经过一段疯狂学习后,算是初窥门径,后续准备把自己一些理解写出来。另一方面,关注图片编辑这个实用的领域,这个领域还处在快速发展之中,希望能做出一点有用的东西。今天这篇文章就是分享阶跃星辰的Step1X-Edit这个图片编辑模型的执行流程。先说说为什么选择这个模型,因为这是第一个使用了vlm的图片编辑模型。个人认为,使用vlm编码文字编辑指令和图片,是一个非常优雅的解决方案,因为可以不费力吸收最新的vlm模型的成果。
我会通过debug的模式,逐步记录执行过程中的关键结果。使用的代码是官方repo中的inference.py文件。
这里是加载模型的代码。整个step1x-edit由三部分组成。autoencoder、dit、llm_encoder。autoencoder用来将图像编码到潜空间,dit用来执行扩散建模,llm_encoder用来编码原图和文字编辑指令。step1x-edit使用Qwen2.5-VL-7B-Instruct模型作为llm_encoder。
self.ae, self.dit, self.llm_encoder = load_models(
dit_path=dit_path,
ae_path=ae_path,
qwen2vl_model_path=qwen2vl_model_path,
max_length=max_length,
dtype=dtype,
device=self.device,
mode=mode
)
我们使用的编辑指令:’给这个女生的脖子上戴一个带有红宝石的吊坠。’ 使用的图片如下所示:
在input_process_image方法中,是对输入图片的w和h的调整。 初始w和h是:
w = 1024
h = 1536
调整后:
w = 416
h = 624
调整逻辑主要是两点:
ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)
上面缩放的图片经过除255乘2减1,每个像素值变成了-1到+1之间。这就可以输入给autoencoder了。 基本过程就是encode加sample。
在encoder中,首先经过一个卷积,将channel从3变为128:
torch.Size([1, 128, 624, 416])
然后就是一系列降采样,shape变为:
torch.Size([1, 512, 78, 52])
然后又经过一系列卷积和attention,最后输出:
torch.Size([1, 32, 78, 52])
这里的78, 52是与原图大小和img_size相关的。32则是固定不变的,被称为z_channels。 整体上,encode进行了8倍降采样。624/78=8, 416/52=8。同样的,decode会进行8倍上采样。 然后是sample:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
mean和logvar的shape都是
torch.Size([1, 16, 78, 52])
通过重参数化进行了抽样。
紧接着,进行了scale和shift变换:
z = self.scale_factor * (z - self.shift_factor)
其中:
scale_factor = 0.3611
shift_factor = 0.1159
这个过程是为了将latent规整到标准正态分布,其计算过程如下:
# 在大量训练数据上统计
mean_z = torch.mean(latents) # 均值 → shift_factor
std_z = torch.std(latents) # 标准差 → 1/scale_factor
shift_factor = mean_z # ≈ 0.1159
scale_factor = 1/std_z # ≈ 0.3611
这个步骤主要是为了更好地与扩散过程兼容,让扩散模型的输入分布更加稳定。
扩散之前,先构建噪声。
x = torch.randn(
num_samples,
16,
height // 8,
width // 8,
device=self.device,
dtype=torch.bfloat16,
generator=torch.Generator(device=self.device).manual_seed(seed),
)
这里的16就是上面提到的z_channels。height和width都除以8,就是因为vae的encode进行了8倍的下采样。
然后,就是获取一个时间步的规划:
timesteps = sampling.get_schedule(
num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
)
'''
[0.9999999403953552, 0.9806238412857056, 0.9605796933174133, 0.9398325681686401, 0.9183447957038879, 0.8960757851600647, 0.872982382774353, 0.8490179777145386, 0.8241321444511414, 0.7982708215713501, 0.7713754177093506, 0.7433827519416809, 0.7142242193222046, 0.6838254332542419, 0.6521055698394775, 0.6189765334129333, 0.584342360496521, 0.5480981469154358, 0.5101287364959717, 0.47030821442604065, 0.42849764227867126, 0.3845440745353699, 0.3382784128189087, 0.28951331973075867, 0.23804061114788055, 0.18362829089164734, 0.1260172426700592, 0.06491677463054657, 0.0]
'''
这里需要指出一个小技巧。28步中,在分配在高时间的步数更多一点。统计一下可知,t>0.5的步骤数是19。 而且,这个19,会随着图片的分辨率增大而增大。
这么做的目的是,对于高分辨率图像,在初期,噪声较大时,多进行几次迭代,小心求索。后期,噪声比较小时,则可以适当加大步伐,尽快收敛。
x = torch.cat([x, x], dim=0)
ref_images = torch.cat([ref_images, ref_images], dim=0)
ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)
这里:
x.shape
torch.Size([2, 16, 78, 52])
ref_images.shape
torch.Size([2, 16, 78, 52])
ref_images_raw.shape
torch.Size([2, 3, 624, 416])
之所以搞双份,是因为要一次性将cond和uncond扩散都完成,属于cfg的基操了。 从这时起,batch就一直是2了。第一个数据是带文本编辑指令的扩散,第二个是没有文本编辑指令的扩散。
下面看看prepare函数: 里面,首先对img和ref_image进行了2x2的分块:
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)
做完这一步时, img和ref_img的shape变为:
torch.Size([2, 1014, 64])
64=16x4
1014=78x52/4
这是vit的核心预处理步骤。可参考下面的示意代码:
# 步骤1:图像分块
patches = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=16, pw=16)
# 步骤2:线性投影到embedding空间
patch_embeddings = self.patch_embed(patches) # (b, num_patches, embed_dim)
# 步骤3:加入位置编码
embeddings = patch_embeddings + self.pos_encoding
# 步骤4:输入Transformer
output = self.transformer(embeddings)
然后是,计算img和ref_image的位置id。
img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)
ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)
因为是2x2的分块,所以这里的位置编码也都是h和w除以2的。这里的位置编码同时编码了行和列。
img和ref_img、img_ids和ref_img_ids最后拼接到一起返回。
再然后,就是使用llm_encoder对prompt和原始图像进行编码。
其返回txt和masks
txt, mask = self.llm_encoder(prompt, ref_image_raw)
txt.shape
torch.Size([2, 640, 3584])
masks.shape
torch.Size([2, 640])
对于指令中有文字编辑的部分,就是带引号的字,比如:给图像添加文字“你好”。做了特殊处理,确保每个字都单独进行tokenize。没有使用qwen默认的tokenize结果。如果不单独处理,qwen tokenizer可能将 你好 作为一个单独的token。就丢失了要写的文字的信息了。 请注意,上面的代码中,llm_encoder返回的是文本编辑指令和参考图像的编码,但是却命名为txt和mask。原因是,在往后的处理逻辑中,就是按照文生图的逻辑来了。txt就是代表了“文”,虽然其中包含了图像信息,往后并不care了。
其次,txt中不包含system prompt部分。从下面代码可以看出:
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
: self.max_length
]
217就是system prompt的长度。
其次,对于无条件的部分,prompt是空的。这一点也需要注意。 输入的prompt实际上是:
['给这个女生的脖子上戴一个带有红宝石的吊坠。', '']
在扩散过程中。txt也有一个位置编码:
txt_ids = torch.zeros(bs, txt.shape[1], 3)
保持了和img相同的channel 3。
最后返回的是一个dict:
{
"img": img,
"mask": mask,
"img_ids": img_ids.to(img.device),
"llm_embedding": txt.to(img.device),
"txt_ids": txt_ids.to(img.device),
}
对应shape:
{
'img': torch.Size([2, 2028, 64]),
'mask': torch.Size([2, 640]),
'img_ids': torch.Size([2, 2028, 3]),
'llm_embedding': torch.Size([2, 640, 3584]),
'txt_ids': torch.Size([2, 640, 3])
}
这里的2028来自于将img和ref_image进行了concat。本来是1014,变成了2028。 注意img中的每条数据中,前一半是要扩散到图像的噪声,后一半是参考图像。也就是为了在扩散过程中的每一步都能有所参考。不过这里已经是经过vae的latent图像了。这是和一般的文生图另一个不同的地方。
第一个不同的地方是编码txt时,同时将文本编辑指令和原始参考图通过vlm编码到一起了。注意这里的参考图不是latent图像,而是真实的图像像素信息。
细细品味,可以知道。后面的扩散过程中,txt中有原始像素信息,img中是latent的图像信息。latent和真实像素信息相互交织。
640来自于qwen vl模型对图片和文本编辑指令的编码。
有了上面的5个值,就可以进行去噪了。
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
x = self.denoise(
**inputs,
cfg_guidance=cfg_guidance,
timesteps=timesteps,
show_progress=show_progress,
timesteps_truncate=1.0,
)
denoise方法整体:
def denoise(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
llm_embedding: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: list[float],
cfg_guidance: float = 4.5,
mask=None,
show_progress=False,
timesteps_truncate=1.0,
):
if self.offload:
self.dit = self.dit.to(self.device)
if show_progress:
pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
else:
pbar = itertools.pairwise(timesteps)
for t_curr, t_prev in pbar:
if img.shape[0] == 1 and cfg_guidance != -1:
img = torch.cat([img, img], dim=0)
t_vec = torch.full(
(img.shape[0],), t_curr, dtype=img.dtype, device=img.device
)
pred = self.dit(
img=img,
img_ids=img_ids,
txt_ids=txt_ids,
timesteps=t_vec,
llm_embedding=llm_embedding,
t_vec=t_vec,
mask=mask,
)
# txt, vec = self.dit.connector(llm_embedding, t_vec, mask)
# pred = self.dit(
# img=img,
# img_ids=img_ids,
# txt=txt,
# txt_ids=txt_ids,
# y=vec,
# timesteps=t_vec,
# )
if cfg_guidance != -1:
cond, uncond = (
pred[0 : pred.shape[0] // 2, :],
pred[pred.shape[0] // 2 :, :],
)
if t_curr > timesteps_truncate:
diff = cond - uncond
diff_norm = torch.norm(diff, dim=(2), keepdim=True)
pred = uncond + cfg_guidance * (
cond - uncond
) / self.process_diff_norm(diff_norm, k=0.4)
else:
pred = uncond + cfg_guidance * (cond - uncond)
tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
img_input_length = img.shape[1] // 2
img = torch.cat(
[
tem_img[:, :img_input_length],
img[ : img.shape[0] // 2, img_input_length:],
], dim=1
)
if self.offload:
self.dit = self.dit.cpu()
cudagc()
return img[:, :img.shape[1] // 2]
这段代码有两个细节值得关注:
img = torch.cat(
[
tem_img[:, :img_input_length],
img[ : img.shape[0] // 2, img_input_length:],
], dim=1
)
这段代码,将ref_img恢复了原样,也就是在每次去噪时,噪声图像旁边都有一个不变的参考图像。具体说,就是img[ : img.shape[0] // 2, img_input_length:]这段代码的作用。
第二个细节,后面的每个时间步中,第二条数据中的img都是和第一条数据保持一致的。上面代码中的img的batch为1,在进入下次循环时,会进行复制:
if img.shape[0] == 1 and cfg_guidance != -1:
img = torch.cat([img, img], dim=0)
这样做的目的是使得在每个时间步中。cond和uncond都在同一起跑线。这才方便对比。
剩下就是比较复杂的dit部分。这里我们先把dit的输入进行明确:
pred = self.dit(
img=img,
img_ids=img_ids,
txt_ids=txt_ids,
timesteps=t_vec,
llm_embedding=llm_embedding,
t_vec=t_vec,
mask=mask,
)
每个输入的信息如下:
img.shape = [2, 2028, 64]
img_ids.shape = [2, 2028, 3]
txt_ids.shape = [2, 640, 3]
timesteps = [0.9805, 0.9805] # 实际值,会动态变化
t_vec = timesteps
llm_embedding.shape = [2, 640, 3584]
mask.shape = [2, 640]
将在下篇中详细追踪dit的执行流程。