Step1X-Edit执行流程(二)


本文是Step1X-Edit执行流程的第二篇,主要介绍去噪过程。

1、connector

进入denoise方法后,首先就是connector:

txt, y = self.connector(
    llm_embedding, t_vec, mask
)

输入的shape:

t_vec = timesteps = [0.9805, 0.9805]  # 实际值,会动态变化
llm_embedding.shape = [2, 640, 3584]
mask.shape = [2, 640]

输出的shape:

txt.shape = [2, 640, 4096]
y.shape = [2, 768]

txt中融入了时间信息。

其中y的获取比较简单:

def forward(self, x,t,mask):
    t = t * 1000 # fix the times embedding bug
    mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
    x_mean = (x * mask_float).sum(
            dim=1
        ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(x.dtype))

    global_out=self.global_proj_out(x_mean)
    encoder_hidden_states = self.S(x,t,mask)
    return encoder_hidden_states,global_out

y就是上面代码中的global_out,文本token的平均值,然后通过一个linear层,将维度从3584降到768。 需要时刻牢记,这里所谓的文本token,是vlm的输出,包含了文本编辑指令和图片信息。

encoder_hidden_states 是SingleTokenRefiner的输出。大致就是将时间信息融入文本token中,并将token的维度从3584升级到4096。

其内部调用:

# class SingleTokenRefiner
def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
        y: torch.LongTensor=None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)
        if self.need_CA:
            y = self.input_embedder_CA(y)
            x = self.individual_token_refiner(x, c, mask, y)
        else:
            x = self.individual_token_refiner(x, c, mask)

        return x

这里面就是将时间t做嵌入,然后对文本token取平均,映射到和时间一个维度,然后相加作为总的context c。c的维度是:[2,4096] 时间嵌入部分就是做了一个正余弦位置编码,然后通过一个linear层,将维度从256升到4096。

x = self.input_embedder(x) 这里也是将x的维度从3584升到4096。 这样x,c的hidden_state维度都是4096。

就传入individual_token_refiner了:

x = self.individual_token_refiner(x, c, mask)

进入到里面:

def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        y:torch.Tensor=None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True
        
        
        for block in self.blocks:
            x = block(x, c, self_attn_mask,y)

        return x

前面处理mask。将其从[2, 640] 变换到:[2,1,640,640]

然后就是循环调用blocks中的每个block:

 def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
        
        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x

上面的代码主要逻辑就三步:

  • 计算x的自注意力
  • 将自注意力乘以以c为条件的gate,加到x上。
  • 将x进行norm和mlp后,再次乘以以c为条件的gate,加到x上。

c里面融合了时间信息和全局信息,应该是想使用这些特征,对token的值进行一个精练。

这个模块叫IndividualTokenRefinerBlock,也体现了其作用。

这样的block循环了两次。

最终得到:encoder_hidden_states

为什么这里要加入时间信息对文本token进行精练呢?

一个容易想到的解释,就是不同的时间步中,希望模型关注到文本token中不同部分的特征。

2、DoubleStreamBlock和SingleStreamBlock

接下来的重头操作就是19层的DoubleStreamBlock变换和38层的SingleStreamBlock变换。将文本和扩散图片信息进行充分的融合。

这里,我比较关注的一点是:ref_image的信息是否得到了特殊的处理。 为什么关注这一点? 因为在使用时,发现step1x-edit经常容易修改指令范围外的内容。我怀疑是ref_image没有得到特殊的重视。 事实确实也是这样,后续噪声图x和ref_image的交互主要是通过注意力进行交互,并没有特殊对待。

img = self.img_in(img)
vec = self.time_in(self.timestep_embedding(timesteps, 256))

vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

这些是进入DoubleStreamBlock前的变换工作。vec有点像connector中的c,融合了token全局信息和时间信息。pe是位置编码。

下面是DoubleStreamBlock内部:

def _forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=1)
        k = torch.cat((txt_k, img_k), dim=1)
        v = torch.cat((txt_v, img_v), dim=1)

        attn = attention_after_rope(q, k, v, pe, self.mode)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img bloks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp = self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )
        img = scale_add_residual(img_mlp, img_mod2.gate, img)

        # calculate the txt bloks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp = self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
        return img, txt

首先,确定下,入参的shape:

img.shape = [2, 2028, 3072]
txt.shape = [2, 640, 3072]
vec.shape = [2, 3072]
pe.shape = [2, 1, 2668, 64, 2, 2]

好奇这里为什么不需要txt的mask了。这里确实是有问题,不过据说这样的操作,会让guidance更加稳定。也是很迷的一件事儿。

pe中的2668 = 2028 + 640

这里的pe的shape乍看起来有点怪,其实,如果记得之前64是2x2的分块,就容易理解这里的shape。

img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)

每个mod中有三部分,shift,scale,gate。shape都是[2, 1, 3072] 根据名字就可以猜出后面的用处。

这里,我们停下来,关注下img。之前的代码显示,这个img是x和ref_img拼接到一起的。shape是[2,2028,64]。其中2028维度,一半是x,要去噪得到的图像,一半是ref_img,是vae encode出来的结果。 最后的64维度,是16 channel 加2x2分块的结果。上篇文章中讲过了。

这里,第一次对最后的64进行变换,变成了3072。

img = self.img_in(img)

然后进行LN,scale和shift,拆分成qkv:

# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(
    img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

文本也做类似的操作。然后就是自注意力。scale shift gate一顿残差连接。

attn = attention_after_rope(q, k, v, pe, self.mode)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img_mlp = self.img_mlp(
    (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
img = scale_add_residual(img_mlp, img_mod2.gate, img)

# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt_mlp = self.txt_mlp(
    (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
return img, txt

这里的双流,就是文本和图像各自做各自的注意力,残差连接。彼此之间没有做交互。

19层的DoubleStreamBlock后,就是SingleStreamBlock了。

img = torch.cat((txt, img), 1)

先把两个concat到一起,此时img的shape:

img.shape = [2, 2668, 3072]

单流部分比较简单:

def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
    mod, _ = self.modulation(vec)
    x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
    qkv, mlp = torch.split(
        self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
    )

    q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
    q, k = self.norm(q, k, v)

    # compute attention
    attn = attention_after_rope(q, k, v, pe, self.mode)
    # compute activation in mlp stream, cat again and run second linear layer
    output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
    return scale_add_residual(output, mod.gate, x)

整体操作和双流中差不多。

3、final layer

过完单流之后。将img部分单独拆出来,准备过最后的layer。

img = img[:, txt.shape[1] :, ...]

final layer的处理也比较简单:

def forward(self, x: Tensor, vec: Tensor) -> Tensor:
    shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
    x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
    x = self.linear(x)
    return x

至此,我们就分析完了Step1X-Edit的所有执行过程。 最后,女孩成功戴上了项链。

20250625175953


原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: Step1X-Edit执行流程(一)