ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl → 0.3.0.dev20240926__py3-none-any.whl
- ai_edge_torch/__init__.py +5 -4
- ai_edge_torch/_convert/conversion.py +112 -0
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +94 -48
- ai_edge_torch/_convert/fx_passes/__init__.py +22 -0
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +107 -44
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +23 -20
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/__init__.py +1 -1
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +39 -9
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +17 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +9 -8
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +31 -18
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +2 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +34 -24
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +66 -0
- ai_edge_torch/_convert/test/test_convert.py +495 -0
- ai_edge_torch/_convert/test/test_convert_composites.py +234 -0
- ai_edge_torch/_convert/test/test_convert_multisig.py +189 -0
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -5
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +10 -3
- ai_edge_torch/config.py +27 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +72 -40
- ai_edge_torch/debug/test/test_culprit.py +7 -5
- ai_edge_torch/debug/test/test_search_model.py +8 -7
- ai_edge_torch/debug/utils.py +14 -3
- ai_edge_torch/fx_pass_base.py +101 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/gemma/{gemma.py → gemma1.py} +69 -55
- ai_edge_torch/generative/examples/gemma/gemma2.py +267 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py +56 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +57 -0
- ai_edge_torch/generative/examples/gemma/verify_util.py +143 -0
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +206 -0
- ai_edge_torch/generative/examples/openelm/verify.py +64 -0
- ai_edge_torch/generative/examples/phi/__init__.py +14 -0
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/{phi2 → phi}/phi2.py +70 -51
- ai_edge_torch/generative/examples/phi/phi3.py +286 -0
- ai_edge_torch/generative/examples/phi/verify.py +65 -0
- ai_edge_torch/generative/examples/phi/verify_phi3.py +70 -0
- ai_edge_torch/generative/examples/smollm/__init__.py +14 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +101 -0
- ai_edge_torch/generative/examples/smollm/verify.py +62 -0
- ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -13
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +27 -14
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +74 -9
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +179 -37
- ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +83 -58
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
- ai_edge_torch/generative/examples/t5/convert_to_tflite.py +28 -25
- ai_edge_torch/generative/examples/t5/t5.py +208 -159
- ai_edge_torch/generative/examples/t5/t5_attention.py +45 -30
- ai_edge_torch/generative/examples/test_models/convert_toy_model.py +105 -0
- ai_edge_torch/generative/examples/test_models/toy_model.py +69 -41
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +50 -64
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +41 -39
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +67 -54
- ai_edge_torch/generative/examples/tiny_llama/verify.py +64 -0
- ai_edge_torch/generative/fx_passes/__init__.py +4 -5
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +10 -7
- ai_edge_torch/generative/layers/attention.py +141 -102
- ai_edge_torch/generative/layers/attention_utils.py +53 -12
- ai_edge_torch/generative/layers/builder.py +37 -7
- ai_edge_torch/generative/layers/feed_forward.py +39 -14
- ai_edge_torch/generative/layers/kv_cache.py +162 -50
- ai_edge_torch/generative/layers/model_config.py +84 -30
- ai_edge_torch/generative/layers/normalization.py +185 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +6 -4
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +48 -21
- ai_edge_torch/generative/layers/unet/blocks_2d.py +136 -77
- ai_edge_torch/generative/layers/unet/builder.py +7 -4
- ai_edge_torch/generative/layers/unet/model_config.py +17 -15
- ai_edge_torch/generative/quantize/example.py +7 -8
- ai_edge_torch/generative/quantize/quant_recipe.py +10 -7
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -1
- ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
- ai_edge_torch/generative/test/test_kv_cache.py +120 -0
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +9 -7
- ai_edge_torch/generative/test/test_model_conversion.py +124 -188
- ai_edge_torch/generative/test/test_model_conversion_large.py +251 -0
- ai_edge_torch/generative/test/test_quantize.py +76 -60
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/loader.py +120 -57
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +165 -57
- ai_edge_torch/generative/utilities/t5_loader.py +110 -81
- ai_edge_torch/generative/utilities/verifier.py +247 -0
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -7
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +39 -30
- ai_edge_torch/hlfb/test/test_mark_pattern.py +46 -20
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +24 -11
- ai_edge_torch/lowertools/__init__.py +18 -0
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +142 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +255 -0
- ai_edge_torch/lowertools/test_utils.py +60 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +284 -0
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +29 -14
- ai_edge_torch/model.py +53 -18
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +357 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +150 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +25 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +258 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +241 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +252 -0
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +96 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/quantize/pt2e_quantizer.py +52 -24
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +43 -23
- ai_edge_torch/quantize/quant_config.py +13 -9
- ai_edge_torch/testing/model_coverage/model_coverage.py +29 -16
- ai_edge_torch/version.py +16 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/METADATA +7 -3
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD +177 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/WHEEL +1 -1
- ai_edge_torch/convert/conversion.py +0 -117
- ai_edge_torch/convert/conversion_utils.py +0 -400
- ai_edge_torch/convert/fx_passes/__init__.py +0 -59
- ai_edge_torch/convert/fx_passes/_pass_base.py +0 -49
- ai_edge_torch/convert/fx_passes/canonicalize_pass.py +0 -37
- ai_edge_torch/convert/test/test_convert.py +0 -311
- ai_edge_torch/convert/test/test_convert_composites.py +0 -192
- ai_edge_torch/convert/test/test_convert_multisig.py +0 -139
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +0 -66
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -64
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -161
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
- ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +0 -121
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{phi2 → openelm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240714.dist-info → ai_edge_torch_nightly-0.3.0.dev20240926.dist-info}/top_level.txt +0 -0
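Note: the listing above reflects an internal restructure (ai_edge_torch/convert → ai_edge_torch/_convert, plus the new odml_torch lowering backend and the lowertools shim), but the public entry point is still ai_edge_torch.convert. A minimal end-to-end sketch (the resnet18 model and input shape are illustrative assumptions, not taken from this diff):

    import ai_edge_torch
    import torch
    import torchvision

    # Any torch.nn.Module in eval mode can be converted.
    model = torchvision.models.resnet18().eval()
    sample_args = (torch.randn(1, 3, 224, 224),)

    # Trace and lower the model, then serialize it as a TFLite flatbuffer.
    edge_model = ai_edge_torch.convert(model, sample_args)
    edge_model.export("resnet18.tflite")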
ai_edge_torch/generative/utilities/stable_diffusion_loader.py CHANGED

@@ -16,11 +16,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import torch
-
 import ai_edge_torch.generative.layers.model_config as layers_config
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 import ai_edge_torch.generative.utilities.loader as loader
+import torch
 
 
 @dataclass
@@ -80,27 +79,35 @@ class TransformerBlockTensorNames:
 class MidBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
   attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
 
 
 @dataclass
 class DownEncoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   downsample_conv: str = None
 
 
 @dataclass
 class UpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
 @dataclass
 class SkipUpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
@@ -119,7 +126,9 @@ def _map_to_converted_state(
       converted_state[f"{converted_state_param}.weight"]
   )
   if f"{state_param}.bias" in state:
-    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+    converted_state[f"{converted_state_param}.bias"] = state.pop(
+        f"{state_param}.bias"
+    )
     if squeeze_dims:
       converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
           converted_state[f"{converted_state_param}.bias"]
@@ -220,25 +229,41 @@ class BaseLoader(loader.ModelLoader):
         f"{attention_layer_prefix}.v_projection",
         squeeze_dims=True,
     )
-    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
-        [
-            converted_state[f"{attention_layer_prefix}.q_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.k_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.v_projection.weight"],
-        ],
-        axis=0,
+    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = (
+        torch.concat(
+            [
+                converted_state[
+                    f"{attention_layer_prefix}.q_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.k_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.v_projection.weight"
+                ],
+            ],
+            axis=0,
+        )
     )
     del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
     del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
     del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
     if config.attention_config.qkv_use_bias:
-      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
-          [
-              converted_state[f"{attention_layer_prefix}.q_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.k_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.v_projection.bias"],
-          ],
-          axis=0,
+      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = (
+          torch.concat(
+              [
+                  converted_state[
+                      f"{attention_layer_prefix}.q_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.k_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.v_projection.bias"
+                  ],
+              ],
+              axis=0,
+          )
       )
       del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
       del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
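Note: the hunk above only re-wraps long lines; the underlying logic fuses a checkpoint's separate q/k/v projection weights (and, when qkv_use_bias is set, their biases) into the single qkv_projection parameter that the generative attention layer expects. A self-contained sketch of the same fusion (the dimensions are illustrative assumptions):

    import torch

    model_dim = 512
    # nn.Linear stores weights as (out_features, in_features).
    q_w = torch.randn(model_dim, model_dim)
    k_w = torch.randn(model_dim, model_dim)
    v_w = torch.randn(model_dim, model_dim)

    # Concatenate along the output axis so one matmul yields q, k and v.
    qkv_w = torch.cat([q_w, k_w, v_w], dim=0)  # (3 * model_dim, model_dim)

    x = torch.randn(1, 10, model_dim)
    q, k, v = torch.nn.functional.linear(x, qkv_w).split(model_dim, dim=-1)
    assert torch.allclose(q, torch.nn.functional.linear(x, q_w))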
@@ -316,11 +341,17 @@ class BaseLoader(loader.ModelLoader):
       )
     else:
       _map_to_converted_state(
-          state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+          state,
+          tensor_names.w1,
+          converted_state,
+          f"{converted_state_param_prefix}.w1",
       )
 
     _map_to_converted_state(
-        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+        state,
+        tensor_names.w2,
+        converted_state,
+        f"{converted_state_param_prefix}.w2",
     )
 
   def _map_transformer_block(
@@ -381,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
   ):
     residual_block_config = unet_config.ResidualBlock2DConfig(
         in_channels=config.in_channels,
+        hidden_channels=config.in_channels,
         out_channels=config.in_channels,
         time_embedding_channels=config.time_embedding_channels,
         normalization_config=config.normalization_config,
@@ -435,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
           f"{converted_state_param_prefix}.resnets.{i}",
           unet_config.ResidualBlock2DConfig(
               in_channels=input_channels,
+              hidden_channels=config.out_channels,
               out_channels=config.out_channels,
               time_embedding_channels=config.time_embedding_channels,
               normalization_config=config.normalization_config,
@@ -477,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
           f"{converted_state_param_prefix}.resnets.{i}",
           unet_config.ResidualBlock2DConfig(
               in_channels=input_channels,
+              hidden_channels=config.out_channels,
               out_channels=config.out_channels,
               time_embedding_channels=config.time_embedding_channels,
               normalization_config=config.normalization_config,
@@ -509,9 +543,13 @@ class BaseLoader(loader.ModelLoader):
   ):
     for i in range(config.num_layers):
       res_skip_channels = (
-          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+          config.in_channels
+          if (i == config.num_layers - 1)
+          else config.out_channels
+      )
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else config.out_channels
       )
-      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
       self._map_residual_block(
           state,
          converted_state,
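Note: this channel arithmetic feeds the SkipUpDecoderBlock mapping in the next hunk: each resnet consumes its own input concatenated with the matching encoder skip tensor. A worked example under illustrative assumptions (num_layers=3, in_channels=320, out_channels=640, prev_out_channels=1280):

    num_layers, in_ch, out_ch, prev_out_ch = 3, 320, 640, 1280
    for i in range(num_layers):
        res_skip_channels = in_ch if i == num_layers - 1 else out_ch
        resnet_in_channels = prev_out_ch if i == 0 else out_ch
        # The residual block sees the skip concat: resnet_in + res_skip.
        print(i, resnet_in_channels + res_skip_channels, "->", out_ch)
    # 0 1920 -> 640
    # 1 1280 -> 640
    # 2 960 -> 640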
@@ -519,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
           f"{converted_state_param_prefix}.resnets.{i}",
           unet_config.ResidualBlock2DConfig(
               in_channels=resnet_in_channels + res_skip_channels,
+              hidden_channels=config.out_channels,
               out_channels=config.out_channels,
               time_embedding_channels=config.time_embedding_channels,
               normalization_config=config.normalization_config,
@@ -559,11 +598,13 @@ class AutoEncoderModelLoader(BaseLoader):
   up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
 
   def __init__(self, file_name: str, names: TensorNames):
-    """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
+    """AutoEncoderModelLoader constructor.
+
+    Can be used to load encoder and decoder models.
 
     Args:
-      file_name (str): Path to the checkpoint. Can be a directory or an
-        exact file.
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
     """
     self._file_name = file_name
@@ -582,7 +623,8 @@ class AutoEncoderModelLoader(BaseLoader):
 
     Returns:
       missing_keys (List[str]): a list of str containing the missing keys.
-      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected
+        keys.
 
     Raises:
       ValueError: If conversion results in unmapped tensors and strict mode is
@@ -599,9 +641,13 @@ class AutoEncoderModelLoader(BaseLoader):
         state, self._names.post_quant_conv, converted_state, "post_quant_conv"
     )
     if self._names.conv_in is not None:
-      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+      _map_to_converted_state(
+          state, self._names.conv_in, converted_state, "conv_in"
+      )
     if self._names.conv_out is not None:
-      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+      _map_to_converted_state(
+          state, self._names.conv_out, converted_state, "conv_out"
+      )
     if self._names.final_norm is not None:
       _map_to_converted_state(
           state, self._names.final_norm, converted_state, "final_norm"
@@ -614,7 +660,9 @@ class AutoEncoderModelLoader(BaseLoader):
         model.config.mid_block_config,
     )
 
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     block_out_channels = reversed_block_out_channels[0]
     for i, out_channels in enumerate(reversed_block_out_channels):
       prev_output_channel = block_out_channels
@@ -642,6 +690,31 @@ class AutoEncoderModelLoader(BaseLoader):
     return model.load_state_dict(converted_state, strict=strict)
 
 
+def build_attention_config(
+    num_heads,
+    dim,
+    num_query_groups,
+    rotary_percentage=0.0,
+    qkv_transpose_before_split=True,
+    qkv_use_bias=False,
+    output_proj_use_bias=True,
+    enable_kv_cache=False,
+    qkv_fused_interleaved=False,
+):
+
+  return layers_config.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=rotary_percentage,
+      qkv_transpose_before_split=qkv_transpose_before_split,
+      qkv_use_bias=qkv_use_bias,
+      output_proj_use_bias=output_proj_use_bias,
+      enable_kv_cache=enable_kv_cache,
+      qkv_fused_interleaved=qkv_fused_interleaved,
+  )
+
+
 class DiffusionModelLoader(BaseLoader):
 
   @dataclass
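Note: the new module-level build_attention_config helper centralizes the AttentionConfig construction that the call sites previously spelled out inline, deriving head_dim as dim // num_heads. A usage sketch mirroring the call sites in the hunks below (the 8-head, 640-channel values are illustrative assumptions):

    attention_config = build_attention_config(
        num_heads=8,
        dim=640,
        # num_query_groups == num_heads means plain multi-head attention.
        num_query_groups=8,
    )
    assert attention_config.head_dim == 80  # 640 // 8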
@@ -655,11 +728,13 @@ class DiffusionModelLoader(BaseLoader):
   up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
 
   def __init__(self, file_name: str, names: TensorNames):
-    """DiffusionModelLoader constructor. Can be used to load diffusion models of Stable Diffusion.
+    """DiffusionModelLoader constructor.
+
+    Can be used to load diffusion models of Stable Diffusion.
 
     Args:
-      file_name (str): Path to the checkpoint. Can be a directory or an
-        exact file.
+      file_name (str): Path to the checkpoint. Can be a directory or an exact
+        file.
       names (TensorNames): An instance of `TensorNames` to determine mappings.
     """
     self._file_name = file_name
@@ -678,7 +753,8 @@ class DiffusionModelLoader(BaseLoader):
 
     Returns:
      missing_keys (List[str]): a list of str containing the missing keys.
-      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected
+        keys.
 
     Raises:
       ValueError: If conversion results in unmapped tensors and strict mode is
@@ -690,20 +766,14 @@ class DiffusionModelLoader(BaseLoader):
     self._map_time_embedding(
         state, converted_state, "time_embedding", self._names.time_embedding
     )
-    _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-    _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
     _map_to_converted_state(
-        state, self._names.final_norm, converted_state, "final_norm"
+        state, self._names.conv_in, converted_state, "conv_in"
     )
-
-    attention_config = layers_config.AttentionConfig(
-        num_heads=config.transformer_num_attention_heads,
-        num_query_groups=config.transformer_num_attention_heads,
-        rotary_percentage=0.0,
-        qkv_transpose_before_split=True,
-        qkv_use_bias=False,
-        output_proj_use_bias=True,
-        enable_kv_cache=False,
+    _map_to_converted_state(
+        state, self._names.conv_out, converted_state, "conv_out"
+    )
+    _map_to_converted_state(
+        state, self._names.final_norm, converted_state, "final_norm"
     )
 
     # Map down_encoders.
@@ -736,13 +806,23 @@ class DiffusionModelLoader(BaseLoader):
             attention_block_config=unet_config.AttentionBlock2DConfig(
                 dim=output_channel,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
             ),
             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                 query_dim=output_channel,
                 cross_dim=config.transformer_cross_attention_dim,
+                hidden_dim=output_channel,
+                output_dim=output_channel,
                 normalization_config=config.transformer_norm_config,
-                attention_config=attention_config,
+                attention_config=build_attention_config(
+                    num_heads=config.transformer_num_attention_heads,
+                    dim=output_channel,
+                    num_query_groups=config.transformer_num_attention_heads,
+                ),
             ),
             pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
             feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -794,13 +874,23 @@ class DiffusionModelLoader(BaseLoader):
         attention_block_config=unet_config.AttentionBlock2DConfig(
             dim=mid_block_channels,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=mid_block_channels,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
         ),
         cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
             query_dim=mid_block_channels,
             cross_dim=config.transformer_cross_attention_dim,
+            hidden_dim=mid_block_channels,
+            output_dim=mid_block_channels,
            normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=mid_block_channels,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
         ),
         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
         feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -825,7 +915,9 @@ class DiffusionModelLoader(BaseLoader):
     )
 
     # Map up_decoders.
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     up_decoder_layers_per_block = config.layers_per_block + 1
     output_channel = reversed_block_out_channels[0]
     for i, block_out_channel in enumerate(reversed_block_out_channels):
@@ -857,13 +949,23 @@ class DiffusionModelLoader(BaseLoader):
            attention_block_config=unet_config.AttentionBlock2DConfig(
                dim=output_channel,
                normalization_config=config.transformer_norm_config,
-               attention_config=attention_config,
+               attention_config=build_attention_config(
+                   num_heads=config.transformer_num_attention_heads,
+                   dim=output_channel,
+                   num_query_groups=config.transformer_num_attention_heads,
+               ),
            ),
            cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
                query_dim=output_channel,
                cross_dim=config.transformer_cross_attention_dim,
+               hidden_dim=output_channel,
+               output_dim=output_channel,
                normalization_config=config.transformer_norm_config,
-               attention_config=attention_config,
+               attention_config=build_attention_config(
+                   num_heads=config.transformer_num_attention_heads,
+                   dim=output_channel,
+                   num_query_groups=config.transformer_num_attention_heads,
+               ),
            ),
            pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
            feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
@@ -917,8 +1019,14 @@ class DiffusionModelLoader(BaseLoader):
       tensor_names: TimeEmbeddingTensorNames,
   ):
     _map_to_converted_state(
-        state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+        state,
+        tensor_names.w1,
+        converted_state,
+        f"{converted_state_param_prefix}.w1",
     )
     _map_to_converted_state(
-        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+        state,
+        tensor_names.w2,
+        converted_state,
+        f"{converted_state_param_prefix}.w2",
     )