Step1X-Edit执行流程(一)


最近一直在做扩散模型相关工作,一方面疯狂吸收理论知识,扩散模型中的数学属实太多,经过一段疯狂学习后,算是初窥门径,后续准备把自己一些理解写出来。另一方面,关注图片编辑这个实用的领域,这个领域还处在快速发展之中,希望能做出一点有用的东西。今天这篇文章就是分享阶跃星辰的Step1X-Edit这个图片编辑模型的执行流程。先说说为什么选择这个模型,因为这是第一个使用了vlm的图片编辑模型。个人认为,使用vlm编码文字编辑指令和图片,是一个非常优雅的解决方案,因为可以不费力吸收最新的vlm模型的成果。

我会通过debug的模式,逐步记录执行过程中的关键结果。使用的代码是官方repo中的inference.py文件。

1、加载模型

这里是加载模型的代码。整个step1x-edit由三部分组成。autoencoder、dit、llm_encoder。autoencoder用来将图像编码到潜空间,dit用来执行扩散建模,llm_encoder用来编码原图和文字编辑指令。step1x-edit使用Qwen2.5-VL-7B-Instruct模型作为llm_encoder。

self.ae, self.dit, self.llm_encoder = load_models(
            dit_path=dit_path,
            ae_path=ae_path,
            qwen2vl_model_path=qwen2vl_model_path,
            max_length=max_length,
            dtype=dtype,
            device=self.device,
            mode=mode
        )

2、确定数据

我们使用的编辑指令:’给这个女生的脖子上戴一个带有红宝石的吊坠。’ 使用的图片如下所示:

20250624135433

3、数据预处理

在input_process_image方法中,是对输入图片的w和h的调整。 初始w和h是:

w = 1024
h = 1536

调整后:

w = 416
h = 624

调整逻辑主要是两点:

  • 面积是img_size*img_size
  • 宽高比不变

4、encode

ref_images = self.ae.encode(ref_images_raw.to(self.device) * 2 - 1)

上面缩放的图片经过除255乘2减1,每个像素值变成了-1到+1之间。这就可以输入给autoencoder了。 基本过程就是encode加sample。

在encoder中,首先经过一个卷积,将channel从3变为128:

torch.Size([1, 128, 624, 416])

然后就是一系列降采样,shape变为:

torch.Size([1, 512, 78, 52])

然后又经过一系列卷积和attention,最后输出:

torch.Size([1, 32, 78, 52])

这里的78, 52是与原图大小和img_size相关的。32则是固定不变的,被称为z_channels。 整体上,encode进行了8倍降采样。624/78=8, 416/52=8。同样的,decode会进行8倍上采样。 然后是sample:

mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
    if self.sample:
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)

mean和logvar的shape都是

torch.Size([1, 16, 78, 52])

通过重参数化进行了抽样。

紧接着,进行了scale和shift变换:

z = self.scale_factor * (z - self.shift_factor)

其中:

scale_factor = 0.3611
shift_factor = 0.1159

这个过程是为了将latent规整到标准正态分布,其计算过程如下:

# 在大量训练数据上统计
mean_z = torch.mean(latents)  # 均值 → shift_factor
std_z = torch.std(latents)    # 标准差 → 1/scale_factor

shift_factor = mean_z         # ≈ 0.1159
scale_factor = 1/std_z        # ≈ 0.3611

这个步骤主要是为了更好地与扩散过程兼容,让扩散模型的输入分布更加稳定。

5、初始化扩散

扩散之前,先构建噪声。

x = torch.randn(
            num_samples,
            16,
            height // 8,
            width // 8,
            device=self.device,
            dtype=torch.bfloat16,
            generator=torch.Generator(device=self.device).manual_seed(seed),
        )

这里的16就是上面提到的z_channels。height和width都除以8,就是因为vae的encode进行了8倍的下采样。

然后,就是获取一个时间步的规划:

timesteps = sampling.get_schedule(
            num_steps, x.shape[-1] * x.shape[-2] // 4, shift=True
        )

'''
[0.9999999403953552, 0.9806238412857056, 0.9605796933174133, 0.9398325681686401, 0.9183447957038879, 0.8960757851600647, 0.872982382774353, 0.8490179777145386, 0.8241321444511414, 0.7982708215713501, 0.7713754177093506, 0.7433827519416809, 0.7142242193222046, 0.6838254332542419, 0.6521055698394775, 0.6189765334129333, 0.584342360496521, 0.5480981469154358, 0.5101287364959717, 0.47030821442604065, 0.42849764227867126, 0.3845440745353699, 0.3382784128189087, 0.28951331973075867, 0.23804061114788055, 0.18362829089164734, 0.1260172426700592, 0.06491677463054657, 0.0]
'''

这里需要指出一个小技巧。28步中,在分配在高时间的步数更多一点。统计一下可知,t>0.5的步骤数是19。 而且,这个19,会随着图片的分辨率增大而增大。

这么做的目的是,对于高分辨率图像,在初期,噪声较大时,多进行几次迭代,小心求索。后期,噪声比较小时,则可以适当加大步伐,尽快收敛。

6、条件编码

x = torch.cat([x, x], dim=0)
ref_images = torch.cat([ref_images, ref_images], dim=0)
ref_images_raw = torch.cat([ref_images_raw, ref_images_raw], dim=0)
inputs = self.prepare([prompt, negative_prompt], x, ref_image=ref_images, ref_image_raw=ref_images_raw)

这里:

x.shape 
torch.Size([2, 16, 78, 52])

ref_images.shape 
torch.Size([2, 16, 78, 52])

ref_images_raw.shape 
torch.Size([2, 3, 624, 416])

之所以搞双份,是因为要一次性将cond和uncond扩散都完成,属于cfg的基操了。 从这时起,batch就一直是2了。第一个数据是带文本编辑指令的扩散,第二个是没有文本编辑指令的扩散。

下面看看prepare函数: 里面,首先对img和ref_image进行了2x2的分块:

img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
ref_img = rearrange(ref_image, "b c (ref_h ph) (ref_w pw) -> b (ref_h ref_w) (c ph pw)", ph=2, pw=2)

做完这一步时, img和ref_img的shape变为:

torch.Size([2, 1014, 64])

64=16x4

1014=78x52/4

这是vit的核心预处理步骤。可参考下面的示意代码:

# 步骤1:图像分块
patches = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=16, pw=16)

# 步骤2:线性投影到embedding空间
patch_embeddings = self.patch_embed(patches)  # (b, num_patches, embed_dim)

# 步骤3:加入位置编码
embeddings = patch_embeddings + self.pos_encoding

# 步骤4:输入Transformer
output = self.transformer(embeddings)

然后是,计算img和ref_image的位置id。

img_ids = torch.zeros(h // 2, w // 2, 3)

img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

ref_img_ids = torch.zeros(ref_h // 2, ref_w // 2, 3)

ref_img_ids[..., 1] = ref_img_ids[..., 1] + torch.arange(ref_h // 2)[:, None]
ref_img_ids[..., 2] = ref_img_ids[..., 2] + torch.arange(ref_w // 2)[None, :]
ref_img_ids = repeat(ref_img_ids, "ref_h ref_w c -> b (ref_h ref_w) c", b=bs)

因为是2x2的分块,所以这里的位置编码也都是h和w除以2的。这里的位置编码同时编码了行和列。

img和ref_img、img_ids和ref_img_ids最后拼接到一起返回。

再然后,就是使用llm_encoder对prompt和原始图像进行编码。

其返回txt和masks

txt, mask = self.llm_encoder(prompt, ref_image_raw)
txt.shape 
torch.Size([2, 640, 3584])
masks.shape 
torch.Size([2, 640])

对于指令中有文字编辑的部分,就是带引号的字,比如:给图像添加文字“你好”。做了特殊处理,确保每个字都单独进行tokenize。没有使用qwen默认的tokenize结果。如果不单独处理,qwen tokenizer可能将 你好 作为一个单独的token。就丢失了要写的文字的信息了。 请注意,上面的代码中,llm_encoder返回的是文本编辑指令和参考图像的编码,但是却命名为txt和mask。原因是,在往后的处理逻辑中,就是按照文生图的逻辑来了。txt就是代表了“文”,虽然其中包含了图像信息,往后并不care了。

其次,txt中不包含system prompt部分。从下面代码可以看出:

embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
]

217就是system prompt的长度。

其次,对于无条件的部分,prompt是空的。这一点也需要注意。 输入的prompt实际上是:

['给这个女生的脖子上戴一个带有红宝石的吊坠。', '']

在扩散过程中。txt也有一个位置编码:

 txt_ids = torch.zeros(bs, txt.shape[1], 3)

保持了和img相同的channel 3。

最后返回的是一个dict:

 {
    "img": img,
    "mask": mask,
    "img_ids": img_ids.to(img.device),
    "llm_embedding": txt.to(img.device),
    "txt_ids": txt_ids.to(img.device),
}

对应shape:

{
    'img': torch.Size([2, 2028, 64]), 
    'mask': torch.Size([2, 640]), 
    'img_ids': torch.Size([2, 2028, 3]), 
    'llm_embedding': torch.Size([2, 640, 3584]), 
    'txt_ids': torch.Size([2, 640, 3])
}

这里的2028来自于将img和ref_image进行了concat。本来是1014,变成了2028。 注意img中的每条数据中,前一半是要扩散到图像的噪声,后一半是参考图像。也就是为了在扩散过程中的每一步都能有所参考。不过这里已经是经过vae的latent图像了。这是和一般的文生图另一个不同的地方。

第一个不同的地方是编码txt时,同时将文本编辑指令和原始参考图通过vlm编码到一起了。注意这里的参考图不是latent图像,而是真实的图像像素信息。

细细品味,可以知道。后面的扩散过程中,txt中有原始像素信息,img中是latent的图像信息。latent和真实像素信息相互交织。

640来自于qwen vl模型对图片和文本编辑指令的编码。

7、去噪

有了上面的5个值,就可以进行去噪了。

with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
    x = self.denoise(
        **inputs,
        cfg_guidance=cfg_guidance,
        timesteps=timesteps,
        show_progress=show_progress,
        timesteps_truncate=1.0,
    )

denoise方法整体:

def denoise(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        llm_embedding: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: list[float],
        cfg_guidance: float = 4.5,
        mask=None,
        show_progress=False,
        timesteps_truncate=1.0,
    ):
        if self.offload:
            self.dit = self.dit.to(self.device)
        if show_progress:
            pbar = tqdm(itertools.pairwise(timesteps), desc='denoising...')
        else:
            pbar = itertools.pairwise(timesteps)
        for t_curr, t_prev in pbar:
            if img.shape[0] == 1 and cfg_guidance != -1:
                img = torch.cat([img, img], dim=0)
            t_vec = torch.full(
                (img.shape[0],), t_curr, dtype=img.dtype, device=img.device
            )

            pred = self.dit(
                img=img,
                img_ids=img_ids,
                txt_ids=txt_ids,
                timesteps=t_vec,
                llm_embedding=llm_embedding,
                t_vec=t_vec,
                mask=mask,
            )
            # txt, vec = self.dit.connector(llm_embedding, t_vec, mask)


            # pred = self.dit(
            #     img=img,
            #     img_ids=img_ids,
            #     txt=txt,
            #     txt_ids=txt_ids,
            #     y=vec,
            #     timesteps=t_vec,
            # )

            if cfg_guidance != -1:
                cond, uncond = (
                    pred[0 : pred.shape[0] // 2, :],
                    pred[pred.shape[0] // 2 :, :],
                )
                if t_curr > timesteps_truncate:
                    diff = cond - uncond
                    diff_norm = torch.norm(diff, dim=(2), keepdim=True)
                    pred = uncond + cfg_guidance * (
                        cond - uncond
                    ) / self.process_diff_norm(diff_norm, k=0.4)
                else:
                    pred = uncond + cfg_guidance * (cond - uncond)
            tem_img = img[0 : img.shape[0] // 2, :] + (t_prev - t_curr) * pred
            img_input_length = img.shape[1] // 2
            img = torch.cat(
                [
                tem_img[:, :img_input_length],
                img[ : img.shape[0] // 2, img_input_length:],
                ], dim=1
            )
        if self.offload:
            self.dit = self.dit.cpu()
            cudagc()

        return img[:, :img.shape[1] // 2]

这段代码有两个细节值得关注:

img = torch.cat(
                [
                tem_img[:, :img_input_length],
                img[ : img.shape[0] // 2, img_input_length:],
                ], dim=1
            )

这段代码,将ref_img恢复了原样,也就是在每次去噪时,噪声图像旁边都有一个不变的参考图像。具体说,就是img[ : img.shape[0] // 2, img_input_length:]这段代码的作用。

第二个细节,后面的每个时间步中,第二条数据中的img都是和第一条数据保持一致的。上面代码中的img的batch为1,在进入下次循环时,会进行复制:

if img.shape[0] == 1 and cfg_guidance != -1:
    img = torch.cat([img, img], dim=0)

这样做的目的是使得在每个时间步中。cond和uncond都在同一起跑线。这才方便对比。

8、dit

剩下就是比较复杂的dit部分。这里我们先把dit的输入进行明确:

pred = self.dit(
    img=img,
    img_ids=img_ids,
    txt_ids=txt_ids,
    timesteps=t_vec,
    llm_embedding=llm_embedding,
    t_vec=t_vec,
    mask=mask,
)

每个输入的信息如下:

img.shape = [2, 2028, 64]
img_ids.shape = [2, 2028, 3]
txt_ids.shape = [2, 640, 3]
timesteps = [0.9805, 0.9805]  # 实际值,会动态变化
t_vec = timesteps
llm_embedding.shape = [2, 640, 3584]
mask.shape = [2, 640]

将在下篇中详细追踪dit的执行流程。


原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: 女儿突然知道关心人了
下篇: Step1X-Edit执行流程(二)