本文是Step1X-Edit执行流程的第二篇,主要介绍去噪过程。
进入denoise方法后,首先就是connector:
txt, y = self.connector(
llm_embedding, t_vec, mask
)
输入的shape:
t_vec = timesteps = [0.9805, 0.9805] # 实际值,会动态变化
llm_embedding.shape = [2, 640, 3584]
mask.shape = [2, 640]
输出的shape:
txt.shape = [2, 640, 4096]
y.shape = [2, 768]
txt中融入了时间信息。
其中y的获取比较简单:
def forward(self, x,t,mask):
t = t * 1000 # fix the times embedding bug
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
x_mean = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(x.dtype))
global_out=self.global_proj_out(x_mean)
encoder_hidden_states = self.S(x,t,mask)
return encoder_hidden_states,global_out
y就是上面代码中的global_out,文本token的平均值,然后通过一个linear层,将维度从3584降到768。 需要时刻牢记,这里所谓的文本token,是vlm的输出,包含了文本编辑指令和图片信息。
encoder_hidden_states 是SingleTokenRefiner的输出。大致就是将时间信息融入文本token中,并将token的维度从3584升级到4096。
其内部调用:
# class SingleTokenRefiner
def forward(
self,
x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
y: torch.LongTensor=None,
):
timestep_aware_representations = self.t_embedder(t)
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:
x = self.individual_token_refiner(x, c, mask)
return x
这里面就是将时间t做嵌入,然后对文本token取平均,映射到和时间一个维度,然后相加作为总的context c。c的维度是:[2,4096] 时间嵌入部分就是做了一个正余弦位置编码,然后通过一个linear层,将维度从256升到4096。
x = self.input_embedder(x) 这里也是将x的维度从3584升到4096。 这样x,c的hidden_state维度都是4096。
就传入individual_token_refiner了:
x = self.individual_token_refiner(x, c, mask)
进入到里面:
def forward(
self,
x: torch.Tensor,
c: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
y:torch.Tensor=None,
):
self_attn_mask = None
if mask is not None:
batch_size = mask.shape[0]
seq_len = mask.shape[1]
mask = mask.to(x.device)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# avoids self-attention weight being NaN for padding tokens
self_attn_mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, self_attn_mask,y)
return x
前面处理mask。将其从[2, 640] 变换到:[2,1,640,640]
然后就是循环调用blocks中的每个block:
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor = None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
if self.need_CA:
x = self.cross_attnblock(x, c, attn_mask, y)
# FFN Layer
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
上面的代码主要逻辑就三步:
c里面融合了时间信息和全局信息,应该是想使用这些特征,对token的值进行一个精练。
这个模块叫IndividualTokenRefinerBlock,也体现了其作用。
这样的block循环了两次。
最终得到:encoder_hidden_states
为什么这里要加入时间信息对文本token进行精练呢?
一个容易想到的解释,就是不同的时间步中,希望模型关注到文本token中不同部分的特征。
接下来的重头操作就是19层的DoubleStreamBlock变换和38层的SingleStreamBlock变换。将文本和扩散图片信息进行充分的融合。
这里,我比较关注的一点是:ref_image的信息是否得到了特殊的处理。 为什么关注这一点? 因为在使用时,发现step1x-edit经常容易修改指令范围外的内容。我怀疑是ref_image没有得到特殊的重视。 事实确实也是这样,后续噪声图x和ref_image的交互主要是通过注意力进行交互,并没有特殊对待。
img = self.img_in(img)
vec = self.time_in(self.timestep_embedding(timesteps, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
这些是进入DoubleStreamBlock前的变换工作。vec有点像connector中的c,融合了token全局信息和时间信息。pe是位置编码。
下面是DoubleStreamBlock内部:
def _forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=1)
k = torch.cat((txt_k, img_k), dim=1)
v = torch.cat((txt_v, img_v), dim=1)
attn = attention_after_rope(q, k, v, pe, self.mode)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img_mlp = self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
img = scale_add_residual(img_mlp, img_mod2.gate, img)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt_mlp = self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
return img, txt
首先,确定下,入参的shape:
img.shape = [2, 2028, 3072]
txt.shape = [2, 640, 3072]
vec.shape = [2, 3072]
pe.shape = [2, 1, 2668, 64, 2, 2]
好奇这里为什么不需要txt的mask了。这里确实是有问题,不过据说这样的操作,会让guidance更加稳定。也是很迷的一件事儿。
pe中的2668 = 2028 + 640
这里的pe的shape乍看起来有点怪,其实,如果记得之前64是2x2的分块,就容易理解这里的shape。
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
每个mod中有三部分,shift,scale,gate。shape都是[2, 1, 3072] 根据名字就可以猜出后面的用处。
这里,我们停下来,关注下img。之前的代码显示,这个img是x和ref_img拼接到一起的。shape是[2,2028,64]。其中2028维度,一半是x,要去噪得到的图像,一半是ref_img,是vae encode出来的结果。 最后的64维度,是16 channel 加2x2分块的结果。上篇文章中讲过了。
这里,第一次对最后的64进行变换,变成了3072。
img = self.img_in(img)
然后进行LN,scale和shift,拆分成qkv:
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
文本也做类似的操作。然后就是自注意力。scale shift gate一顿残差连接。
attn = attention_after_rope(q, k, v, pe, self.mode)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img_mlp = self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
img = scale_add_residual(img_mlp, img_mod2.gate, img)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt_mlp = self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
return img, txt
这里的双流,就是文本和图像各自做各自的注意力,残差连接。彼此之间没有做交互。
19层的DoubleStreamBlock后,就是SingleStreamBlock了。
img = torch.cat((txt, img), 1)
先把两个concat到一起,此时img的shape:
img.shape = [2, 2668, 3072]
单流部分比较简单:
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention_after_rope(q, k, v, pe, self.mode)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return scale_add_residual(output, mod.gate, x)
整体操作和双流中差不多。
过完单流之后。将img部分单独拆出来,准备过最后的layer。
img = img[:, txt.shape[1] :, ...]
final layer的处理也比较简单:
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
至此,我们就分析完了Step1X-Edit的所有执行过程。 最后,女孩成功戴上了项链。