InvokeAI 6.9.0rc3__py3-none-any.whl → 6.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invokeai/app/api/dependencies.py +2 -0
- invokeai/app/api/routers/model_manager.py +91 -2
- invokeai/app/api/routers/workflows.py +9 -0
- invokeai/app/invocations/fields.py +19 -0
- invokeai/app/invocations/flux_denoise.py +15 -1
- invokeai/app/invocations/image_to_latents.py +23 -5
- invokeai/app/invocations/latents_to_image.py +2 -25
- invokeai/app/invocations/metadata.py +9 -1
- invokeai/app/invocations/metadata_linked.py +47 -0
- invokeai/app/invocations/model.py +8 -0
- invokeai/app/invocations/pbr_maps.py +59 -0
- invokeai/app/invocations/primitives.py +12 -0
- invokeai/app/invocations/prompt_template.py +57 -0
- invokeai/app/invocations/z_image_control.py +112 -0
- invokeai/app/invocations/z_image_denoise.py +770 -0
- invokeai/app/invocations/z_image_image_to_latents.py +102 -0
- invokeai/app/invocations/z_image_latents_to_image.py +103 -0
- invokeai/app/invocations/z_image_lora_loader.py +153 -0
- invokeai/app/invocations/z_image_model_loader.py +135 -0
- invokeai/app/invocations/z_image_text_encoder.py +197 -0
- invokeai/app/services/config/config_default.py +3 -1
- invokeai/app/services/model_install/model_install_common.py +14 -1
- invokeai/app/services/model_install/model_install_default.py +119 -19
- invokeai/app/services/model_manager/model_manager_default.py +7 -0
- invokeai/app/services/model_records/model_records_base.py +12 -0
- invokeai/app/services/model_records/model_records_sql.py +17 -0
- invokeai/app/services/shared/graph.py +132 -77
- invokeai/app/services/workflow_records/workflow_records_base.py +8 -0
- invokeai/app/services/workflow_records/workflow_records_sqlite.py +42 -0
- invokeai/app/util/step_callback.py +3 -0
- invokeai/backend/flux/denoise.py +196 -11
- invokeai/backend/flux/schedulers.py +62 -0
- invokeai/backend/image_util/pbr_maps/architecture/block.py +367 -0
- invokeai/backend/image_util/pbr_maps/architecture/pbr_rrdb_net.py +70 -0
- invokeai/backend/image_util/pbr_maps/pbr_maps.py +141 -0
- invokeai/backend/image_util/pbr_maps/utils/image_ops.py +93 -0
- invokeai/backend/model_manager/configs/controlnet.py +47 -1
- invokeai/backend/model_manager/configs/factory.py +26 -1
- invokeai/backend/model_manager/configs/lora.py +79 -1
- invokeai/backend/model_manager/configs/main.py +113 -0
- invokeai/backend/model_manager/configs/qwen3_encoder.py +156 -0
- invokeai/backend/model_manager/load/model_cache/model_cache.py +104 -2
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_diffusers_rms_norm.py +40 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_layer_norm.py +25 -0
- invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +11 -2
- invokeai/backend/model_manager/load/model_loaders/cogview4.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/flux.py +13 -6
- invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py +4 -2
- invokeai/backend/model_manager/load/model_loaders/lora.py +11 -0
- invokeai/backend/model_manager/load/model_loaders/onnx.py +1 -0
- invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +2 -1
- invokeai/backend/model_manager/load/model_loaders/z_image.py +969 -0
- invokeai/backend/model_manager/load/model_util.py +6 -1
- invokeai/backend/model_manager/metadata/metadata_base.py +12 -5
- invokeai/backend/model_manager/model_on_disk.py +3 -0
- invokeai/backend/model_manager/starter_models.py +79 -0
- invokeai/backend/model_manager/taxonomy.py +5 -0
- invokeai/backend/model_manager/util/select_hf_files.py +23 -8
- invokeai/backend/patches/layer_patcher.py +34 -16
- invokeai/backend/patches/layers/lora_layer_base.py +2 -1
- invokeai/backend/patches/lora_conversions/flux_aitoolkit_lora_conversion_utils.py +17 -2
- invokeai/backend/patches/lora_conversions/flux_xlabs_lora_conversion_utils.py +92 -0
- invokeai/backend/patches/lora_conversions/formats.py +5 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_constants.py +8 -0
- invokeai/backend/patches/lora_conversions/z_image_lora_conversion_utils.py +189 -0
- invokeai/backend/quantization/gguf/ggml_tensor.py +38 -4
- invokeai/backend/quantization/gguf/loaders.py +47 -12
- invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +13 -0
- invokeai/backend/util/devices.py +25 -0
- invokeai/backend/util/hotfixes.py +2 -2
- invokeai/backend/z_image/__init__.py +16 -0
- invokeai/backend/z_image/extensions/__init__.py +1 -0
- invokeai/backend/z_image/extensions/regional_prompting_extension.py +205 -0
- invokeai/backend/z_image/text_conditioning.py +74 -0
- invokeai/backend/z_image/z_image_control_adapter.py +238 -0
- invokeai/backend/z_image/z_image_control_transformer.py +643 -0
- invokeai/backend/z_image/z_image_controlnet_extension.py +531 -0
- invokeai/backend/z_image/z_image_patchify_utils.py +135 -0
- invokeai/backend/z_image/z_image_transformer_patch.py +234 -0
- invokeai/frontend/web/dist/assets/App-BBELGD-n.js +161 -0
- invokeai/frontend/web/dist/assets/{browser-ponyfill-CN1j0ARZ.js → browser-ponyfill-4xPFTMT3.js} +1 -1
- invokeai/frontend/web/dist/assets/index-vCDSQboA.js +530 -0
- invokeai/frontend/web/dist/index.html +1 -1
- invokeai/frontend/web/dist/locales/de.json +24 -6
- invokeai/frontend/web/dist/locales/en-GB.json +1 -0
- invokeai/frontend/web/dist/locales/en.json +78 -3
- invokeai/frontend/web/dist/locales/es.json +0 -5
- invokeai/frontend/web/dist/locales/fr.json +0 -6
- invokeai/frontend/web/dist/locales/it.json +17 -64
- invokeai/frontend/web/dist/locales/ja.json +379 -44
- invokeai/frontend/web/dist/locales/ru.json +0 -6
- invokeai/frontend/web/dist/locales/vi.json +7 -54
- invokeai/frontend/web/dist/locales/zh-CN.json +0 -6
- invokeai/version/invokeai_version.py +1 -1
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/METADATA +4 -4
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/RECORD +102 -71
- invokeai/frontend/web/dist/assets/App-Cn9UyjoV.js +0 -161
- invokeai/frontend/web/dist/assets/index-BDrf9CL-.js +0 -530
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/WHEEL +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/entry_points.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SD1+SD2.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/licenses/LICENSE-SDXL.txt +0 -0
- {invokeai-6.9.0rc3.dist-info → invokeai-6.10.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class ZImageTextConditioning:
|
|
10
|
+
"""Z-Image text conditioning with optional regional mask.
|
|
11
|
+
|
|
12
|
+
Attributes:
|
|
13
|
+
prompt_embeds: Text embeddings from Qwen3 encoder. Shape: (seq_len, hidden_size).
|
|
14
|
+
mask: Optional binary mask for regional prompting. If None, the prompt is global.
|
|
15
|
+
Shape: (1, 1, img_seq_len) where img_seq_len = (H // patch_size) * (W // patch_size).
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
prompt_embeds: torch.Tensor
|
|
19
|
+
mask: torch.Tensor | None = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class ZImageRegionalTextConditioning:
    """Container for multiple regional text conditionings concatenated together.

    In Z-Image, the unified sequence is [img_tokens, txt_tokens], which is different
    from FLUX where it's [txt_tokens, img_tokens]. The attention mask must account for this.

    Attributes:
        prompt_embeds: Concatenated text embeddings from all regional prompts.
            Shape: (total_seq_len, hidden_size).
        image_masks: List of binary masks for each regional prompt.
            image_masks[i] corresponds to embedding_ranges[i].
            If None, the prompt is global (applies to entire image).
            Shape: (1, 1, img_seq_len).
        embedding_ranges: List of ranges indicating which portion of prompt_embeds
            corresponds to each regional prompt.
    """

    prompt_embeds: torch.Tensor
    image_masks: list[torch.Tensor | None]
    embedding_ranges: list[Range]

    @classmethod
    def from_text_conditionings(
        cls,
        text_conditionings: list[ZImageTextConditioning],
    ) -> "ZImageRegionalTextConditioning":
        """Create a ZImageRegionalTextConditioning from a list of ZImageTextConditioning objects.

        The embeddings are concatenated along dimension 0 (the token axis) in input
        order. For each input, a Range records the half-open [start, end) slice of the
        concatenated tensor that came from it. Its mask is kept at the same index in
        `image_masks`.

        Args:
            text_conditionings: List of text conditionings, each with optional mask.

        Returns:
            A single ZImageRegionalTextConditioning with concatenated embeddings.

        Raises:
            ValueError: If `text_conditionings` is empty.
        """
        concat_embeds: list[torch.Tensor] = []
        concat_ranges: list[Range] = []
        image_masks: list[torch.Tensor | None] = []

        cur_embed_len = 0
        # Single pass so that any iterable (not just a list) is consumed exactly once.
        for tc in text_conditionings:
            seq_len = tc.prompt_embeds.shape[0]
            concat_embeds.append(tc.prompt_embeds)
            concat_ranges.append(Range(start=cur_embed_len, end=cur_embed_len + seq_len))
            image_masks.append(tc.mask)
            cur_embed_len += seq_len

        if not concat_embeds:
            # torch.cat([]) would otherwise fail with an opaque RuntimeError.
            raise ValueError("from_text_conditionings requires at least one text conditioning.")

        prompt_embeds = torch.cat(concat_embeds, dim=0)

        return cls(
            prompt_embeds=prompt_embeds,
            image_masks=image_masks,
            embedding_ranges=concat_ranges,
        )
|
|
@@ -0,0 +1,238 @@
|
|
|
1
|
+
# Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/videox_fun/models/z_image_transformer2d_control.py
|
|
2
|
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
3
|
+
# Apache License 2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Z-Image Control Adapter for InvokeAI.
|
|
7
|
+
|
|
8
|
+
This module provides a standalone control adapter that can be combined with
|
|
9
|
+
a base ZImageTransformer2DModel at runtime. The adapter contains only the
|
|
10
|
+
control-specific layers (control_layers, control_all_x_embedder, control_noise_refiner).
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import List, Optional
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
import torch.nn as nn
|
|
17
|
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
18
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
19
|
+
from diffusers.models.transformers.transformer_z_image import (
|
|
20
|
+
SEQ_MULTI_OF,
|
|
21
|
+
ZImageTransformerBlock,
|
|
22
|
+
)
|
|
23
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    """Transformer block for the control branch that also emits a skip ("hint") output.

    Every block owns a zero-initialised `after_proj` that turns its output into a hint.
    Only the first block (block_id == 0) also owns a zero-initialised `before_proj`,
    which injects the main-path hidden states `x` into the control stream.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation: bool = True,
        block_id: int = 0,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
        self.block_id = block_id
        # Creation order (before_proj first, then after_proj) is kept so that module
        # registration and RNG consumption match the original layout.
        if block_id == 0:
            self.before_proj = self._zero_initialized_linear(dim)
        self.after_proj = self._zero_initialized_linear(dim)

    @staticmethod
    def _zero_initialized_linear(dim: int) -> nn.Linear:
        """Build a square nn.Linear whose weight and bias are all zeros."""
        projection = nn.Linear(dim, dim)
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)
        return projection

    def forward(
        self,
        c: torch.Tensor,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one control block and append its hint to the stacked control state.

        Args:
            c: For block 0, the raw control token stream. For later blocks, the tensor
                stacked along dim 0 by the previous block: earlier hints followed by
                the running control state as the last entry.
            x: Main-path hidden states; only read by block 0.
            attn_mask: Attention mask forwarded to the base block.
            freqs_cis: RoPE frequencies forwarded to the base block.
            adaln_input: Optional modulation input forwarded to the base block.

        Returns:
            torch.stack of [*previous_hints, this_hint, running_state].
        """
        if self.block_id == 0:
            running = self.before_proj(c) + x
            collected: list[torch.Tensor] = []
        else:
            # Everything but the last stacked entry is an earlier hint; the last is the state.
            *collected, running = c.unbind(0)

        running = super().forward(running, attn_mask=attn_mask, freqs_cis=freqs_cis, adaln_input=adaln_input)
        collected.extend((self.after_proj(running), running))
        return torch.stack(collected)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ZImageControlAdapter(ModelMixin, ConfigMixin):
    """Standalone Z-Image Control Adapter.

    This adapter contains only the control-specific layers and can be combined
    with a base ZImageTransformer2DModel at runtime. It computes control hints
    that are added to the transformer's hidden states.

    The adapter supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
    Recommended control_context_scale: 0.65-0.80.
    """

    @register_to_config
    def __init__(
        self,
        num_control_blocks: int = 6,  # Number of control layer blocks
        control_in_dim: int = 16,
        all_patch_size: tuple[int, ...] = (2,),
        all_f_patch_size: tuple[int, ...] = (1,),
        dim: int = 3840,
        n_refiner_layers: int = 2,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
    ):
        super().__init__()

        self.dim = dim
        self.control_in_dim = control_in_dim
        self.all_patch_size = all_patch_size
        self.all_f_patch_size = all_f_patch_size

        # Control patch embeddings: one Linear per (patch_size, f_patch_size) pair,
        # keyed as "<patch_size>-<f_patch_size>" (the key forward() looks up).
        all_x_embedder = {}
        for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size, strict=True):
            x_embedder = nn.Linear(
                f_patch_size * patch_size * patch_size * control_in_dim,
                dim,
                bias=True,
            )
            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)

        # Control noise refiner. The 1000 + layer_id offset presumably keeps these layer
        # ids distinct from the control blocks' ids (0..num_control_blocks-1) — TODO confirm.
        self.control_noise_refiner = nn.ModuleList(
            [
                ZImageTransformerBlock(
                    1000 + layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        # Control transformer blocks
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                )
                for i in range(num_control_blocks)
            ]
        )

        # Padding token written into the patch positions flagged by patchify's inner pad mask.
        self.x_pad_token = nn.Parameter(torch.empty(dim))
        nn.init.normal_(self.x_pad_token, std=0.02)

    def forward(
        self,
        control_context: List[torch.Tensor],
        unified_hidden_states: torch.Tensor,
        cap_feats: torch.Tensor,
        timestep_emb: torch.Tensor,
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        rope_embedder,
        patchify_fn,
        patch_size: int = 2,
        f_patch_size: int = 1,
    ) -> tuple[torch.Tensor, ...]:
        """Compute control hints from control context.

        Args:
            control_context: List of control image latents [C, 1, H, W]
            unified_hidden_states: Combined image+caption embeddings from main path
            cap_feats: Caption feature embeddings
            timestep_emb: Timestep embeddings
            attn_mask: Attention mask
            freqs_cis: RoPE frequencies
            rope_embedder: RoPE embedder from base model
            patchify_fn: Patchify function from base model
            patch_size: Spatial patch size
            f_patch_size: Frame patch size

        Returns:
            Tuple of hint tensors to be added at each control layer position. There is
            one hint per control layer: the running control state appended by the last
            layer is dropped.

        Raises:
            ValueError: If `control_context` is empty, or if `patchify_fn` returns token
                sequences whose lengths are not multiples of SEQ_MULTI_OF.
        """
        if len(control_context) == 0:
            raise ValueError("control_context must contain at least one control latent")
        bsz = len(control_context)
        device = control_context[0].device

        # Patchify control context using base model's patchify
        (
            control_context_patches,
            _x_size,
            x_pos_ids,
            x_inner_pad_mask,
        ) = patchify_fn(control_context, patch_size, f_patch_size, cap_feats.size(1))

        # Embed control context. Validate with an explicit raise (not assert) because
        # patchify_fn is caller-supplied and asserts are stripped under `python -O`.
        x_item_seqlens = [len(patches) for patches in control_context_patches]
        if any(seq_len % SEQ_MULTI_OF != 0 for seq_len in x_item_seqlens):
            raise ValueError(
                f"Control token sequence lengths must be multiples of {SEQ_MULTI_OF}, got {x_item_seqlens}"
            )
        x_max_item_seqlen = max(x_item_seqlens)

        control_context_cat = torch.cat(control_context_patches, dim=0)
        control_context_cat = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context_cat)

        # Match timestep dtype
        adaln_input = timestep_emb.type_as(control_context_cat)
        # Masked index-put requires source and destination to share dtype and device; align
        # the pad token with the embedded tokens (a no-op when they already match).
        pad_token = self.x_pad_token.to(device=control_context_cat.device, dtype=control_context_cat.dtype)
        control_context_cat[torch.cat(x_inner_pad_mask)] = pad_token
        control_context_list = list(control_context_cat.split(x_item_seqlens, dim=0))
        x_freqs_cis = list(rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

        control_context_padded = pad_sequence(control_context_list, batch_first=True, padding_value=0.0)
        x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # True for real tokens, False for the right-padding added by pad_sequence.
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        # Refine control context
        for layer in self.control_noise_refiner:
            control_context_padded = layer(control_context_padded, x_attn_mask, x_freqs_cis, adaln_input)

        # Unify with caption features: each item becomes [control tokens, caption tokens].
        cap_len = cap_feats.size(1)
        control_context_unified = []
        for i in range(bsz):
            x_len = x_item_seqlens[i]
            control_context_unified.append(torch.cat([control_context_padded[i][:x_len], cap_feats[i][:cap_len]]))
        c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)

        # Process through control layers; each layer stacks its hint onto `c`.
        for layer in self.control_layers:
            c = layer(
                c,
                x=unified_hidden_states,
                attn_mask=attn_mask,
                freqs_cis=freqs_cis,
                adaln_input=adaln_input,
            )

        # Drop the trailing running state; the remaining entries are the per-layer hints.
        hints = torch.unbind(c)[:-1]
        return hints
|