ai-edge-torch-nightly 0.2.0.dev20240714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (121)
  1. ai_edge_torch/__init__.py +31 -0
  2. ai_edge_torch/convert/__init__.py +14 -0
  3. ai_edge_torch/convert/conversion.py +117 -0
  4. ai_edge_torch/convert/conversion_utils.py +400 -0
  5. ai_edge_torch/convert/converter.py +202 -0
  6. ai_edge_torch/convert/fx_passes/__init__.py +59 -0
  7. ai_edge_torch/convert/fx_passes/_pass_base.py +49 -0
  8. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +225 -0
  9. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +123 -0
  10. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +37 -0
  11. ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py +73 -0
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +215 -0
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py +48 -0
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +17 -0
  16. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +59 -0
  17. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +215 -0
  18. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +400 -0
  19. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +30 -0
  20. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +293 -0
  21. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py +62 -0
  22. ai_edge_torch/convert/test/__init__.py +14 -0
  23. ai_edge_torch/convert/test/test_convert.py +311 -0
  24. ai_edge_torch/convert/test/test_convert_composites.py +192 -0
  25. ai_edge_torch/convert/test/test_convert_multisig.py +139 -0
  26. ai_edge_torch/convert/test/test_to_channel_last_io.py +96 -0
  27. ai_edge_torch/convert/to_channel_last_io.py +85 -0
  28. ai_edge_torch/debug/__init__.py +17 -0
  29. ai_edge_torch/debug/culprit.py +464 -0
  30. ai_edge_torch/debug/test/__init__.py +14 -0
  31. ai_edge_torch/debug/test/test_culprit.py +133 -0
  32. ai_edge_torch/debug/test/test_search_model.py +50 -0
  33. ai_edge_torch/debug/utils.py +48 -0
  34. ai_edge_torch/experimental/__init__.py +14 -0
  35. ai_edge_torch/generative/__init__.py +14 -0
  36. ai_edge_torch/generative/examples/__init__.py +14 -0
  37. ai_edge_torch/generative/examples/gemma/__init__.py +14 -0
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +66 -0
  39. ai_edge_torch/generative/examples/gemma/gemma.py +174 -0
  40. ai_edge_torch/generative/examples/phi2/__init__.py +14 -0
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +64 -0
  42. ai_edge_torch/generative/examples/phi2/phi2.py +164 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/__init__.py +14 -0
  44. ai_edge_torch/generative/examples/stable_diffusion/attention.py +106 -0
  45. ai_edge_torch/generative/examples/stable_diffusion/clip.py +115 -0
  46. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +142 -0
  47. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +317 -0
  48. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +573 -0
  49. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +118 -0
  50. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +222 -0
  51. ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py +19 -0
  52. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +61 -0
  53. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +65 -0
  54. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +73 -0
  55. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +38 -0
  56. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +108 -0
  57. ai_edge_torch/generative/examples/stable_diffusion/util.py +71 -0
  58. ai_edge_torch/generative/examples/t5/__init__.py +14 -0
  59. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +135 -0
  60. ai_edge_torch/generative/examples/t5/t5.py +608 -0
  61. ai_edge_torch/generative/examples/t5/t5_attention.py +231 -0
  62. ai_edge_torch/generative/examples/test_models/__init__.py +14 -0
  63. ai_edge_torch/generative/examples/test_models/toy_model.py +122 -0
  64. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
  65. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +143 -0
  66. ai_edge_torch/generative/examples/tiny_llama/__init__.py +0 -0
  67. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +66 -0
  68. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +164 -0
  69. ai_edge_torch/generative/fx_passes/__init__.py +31 -0
  70. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +47 -0
  71. ai_edge_torch/generative/layers/__init__.py +14 -0
  72. ai_edge_torch/generative/layers/attention.py +354 -0
  73. ai_edge_torch/generative/layers/attention_utils.py +169 -0
  74. ai_edge_torch/generative/layers/builder.py +131 -0
  75. ai_edge_torch/generative/layers/feed_forward.py +95 -0
  76. ai_edge_torch/generative/layers/kv_cache.py +83 -0
  77. ai_edge_torch/generative/layers/model_config.py +158 -0
  78. ai_edge_torch/generative/layers/normalization.py +62 -0
  79. ai_edge_torch/generative/layers/rotary_position_embedding.py +36 -0
  80. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
  81. ai_edge_torch/generative/layers/unet/__init__.py +14 -0
  82. ai_edge_torch/generative/layers/unet/blocks_2d.py +711 -0
  83. ai_edge_torch/generative/layers/unet/builder.py +47 -0
  84. ai_edge_torch/generative/layers/unet/model_config.py +269 -0
  85. ai_edge_torch/generative/quantize/__init__.py +14 -0
  86. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -0
  87. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +148 -0
  88. ai_edge_torch/generative/quantize/example.py +45 -0
  89. ai_edge_torch/generative/quantize/quant_attrs.py +68 -0
  90. ai_edge_torch/generative/quantize/quant_recipe.py +151 -0
  91. ai_edge_torch/generative/quantize/quant_recipe_utils.py +51 -0
  92. ai_edge_torch/generative/quantize/quant_recipes.py +48 -0
  93. ai_edge_torch/generative/quantize/supported_schemes.py +32 -0
  94. ai_edge_torch/generative/test/__init__.py +14 -0
  95. ai_edge_torch/generative/test/loader_test.py +80 -0
  96. ai_edge_torch/generative/test/test_model_conversion.py +235 -0
  97. ai_edge_torch/generative/test/test_quantize.py +162 -0
  98. ai_edge_torch/generative/utilities/__init__.py +15 -0
  99. ai_edge_torch/generative/utilities/loader.py +328 -0
  100. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +924 -0
  101. ai_edge_torch/generative/utilities/t5_loader.py +483 -0
  102. ai_edge_torch/hlfb/__init__.py +16 -0
  103. ai_edge_torch/hlfb/mark_pattern/__init__.py +139 -0
  104. ai_edge_torch/hlfb/mark_pattern/passes.py +42 -0
  105. ai_edge_torch/hlfb/mark_pattern/pattern.py +273 -0
  106. ai_edge_torch/hlfb/test/__init__.py +14 -0
  107. ai_edge_torch/hlfb/test/test_mark_pattern.py +133 -0
  108. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +270 -0
  109. ai_edge_torch/model.py +142 -0
  110. ai_edge_torch/quantize/__init__.py +16 -0
  111. ai_edge_torch/quantize/pt2e_quantizer.py +438 -0
  112. ai_edge_torch/quantize/pt2e_quantizer_utils.py +1041 -0
  113. ai_edge_torch/quantize/quant_config.py +81 -0
  114. ai_edge_torch/testing/__init__.py +14 -0
  115. ai_edge_torch/testing/model_coverage/__init__.py +16 -0
  116. ai_edge_torch/testing/model_coverage/model_coverage.py +132 -0
  117. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/LICENSE +202 -0
  118. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/METADATA +38 -0
  119. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/RECORD +121 -0
  120. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/WHEEL +5 -0
  121. ai_edge_torch_nightly-0.2.0.dev20240714.dist-info/top_level.txt +1 -0
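
For orientation, the package's public entry point is the ai_edge_torch.convert API added in ai_edge_torch/__init__.py and convert/converter.py above. A minimal conversion sketch, assuming the convert-and-export flow that ai-edge-torch documents; the torchvision model and output path are illustrative, not part of this diff:

import torch
import torchvision

import ai_edge_torch

# Illustrative model and inputs; any traceable eval-mode torch.nn.Module works.
model = torchvision.models.resnet18(weights=None).eval()
sample_inputs = (torch.randn(1, 3, 224, 224),)

edge_model = ai_edge_torch.convert(model, sample_inputs)  # returns an edge model wrapper
edge_model.export("resnet18.tflite")  # writes the converted TFLite flatbuffer

The remainder of this diff is the single largest new file, ai_edge_torch/generative/utilities/stable_diffusion_loader.py (+924 lines).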
ai_edge_torch/generative/utilities/stable_diffusion_loader.py
@@ -0,0 +1,924 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Common utility functions for data loading etc.
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple
+
+ import torch
+
+ import ai_edge_torch.generative.layers.model_config as layers_config
+ import ai_edge_torch.generative.layers.unet.model_config as unet_config
+ import ai_edge_torch.generative.utilities.loader as loader
+
+
+ @dataclass
+ class ResidualBlockTensorNames:
+   norm_1: str = None
+   conv_1: str = None
+   norm_2: str = None
+   conv_2: str = None
+   residual_layer: str = None
+   time_embedding: str = None
+
+
+ @dataclass
+ class AttentionBlockTensorNames:
+   norm: str = None
+   fused_qkv_proj: str = None
+   q_proj: str = None
+   k_proj: str = None
+   v_proj: str = None
+   output_proj: str = None
+
+
+ @dataclass
+ class CrossAttentionBlockTensorNames:
+   norm: str = None
+   q_proj: str = None
+   k_proj: str = None
+   v_proj: str = None
+   output_proj: str = None
+
+
+ @dataclass
+ class TimeEmbeddingTensorNames:
+   w1: str = None
+   w2: str = None
+
+
+ @dataclass
+ class FeedForwardBlockTensorNames:
+   w1: str = None
+   w2: str = None
+   norm: str = None
+   ge_glu: str = None
+
+
+ @dataclass
+ class TransformerBlockTensorNames:
+   pre_conv_norm: str
+   conv_in: str
+   self_attention: AttentionBlockTensorNames
+   cross_attention: CrossAttentionBlockTensorNames
+   feed_forward: FeedForwardBlockTensorNames
+   conv_out: str
+
+
+ @dataclass
+ class MidBlockTensorNames:
+   residual_block_tensor_names: List[ResidualBlockTensorNames]
+   attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
+   transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+
+
+ @dataclass
+ class DownEncoderBlockTensorNames:
+   residual_block_tensor_names: List[ResidualBlockTensorNames]
+   transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+   downsample_conv: str = None
+
+
+ @dataclass
+ class UpDecoderBlockTensorNames:
+   residual_block_tensor_names: List[ResidualBlockTensorNames]
+   transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+   upsample_conv: str = None
+
+
+ @dataclass
+ class SkipUpDecoderBlockTensorNames:
+   residual_block_tensor_names: List[ResidualBlockTensorNames]
+   transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+   upsample_conv: str = None
+
+
+ def _map_to_converted_state(
+     state: Dict[str, torch.Tensor],
+     state_param: str,
+     converted_state: Dict[str, torch.Tensor],
+     converted_state_param: str,
+     squeeze_dims: bool = False,
+ ):
+   converted_state[f"{converted_state_param}.weight"] = state.pop(
+       f"{state_param}.weight"
+   )
+   if squeeze_dims:
+     converted_state[f"{converted_state_param}.weight"] = torch.squeeze(
+         converted_state[f"{converted_state_param}.weight"]
+     )
+   if f"{state_param}.bias" in state:
+     converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+     if squeeze_dims:
+       converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
+           converted_state[f"{converted_state_param}.bias"]
+       )
+
+
+ class BaseLoader(loader.ModelLoader):
+
+   def _map_residual_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: ResidualBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.ResidualBlock2DConfig,
+   ):
+     _map_to_converted_state(
+         state,
+         tensor_names.norm_1,
+         converted_state,
+         f"{converted_state_param_prefix}.norm_1",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.conv_1,
+         converted_state,
+         f"{converted_state_param_prefix}.conv_1",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.norm_2,
+         converted_state,
+         f"{converted_state_param_prefix}.norm_2",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.conv_2,
+         converted_state,
+         f"{converted_state_param_prefix}.conv_2",
+     )
+     if config.in_channels != config.out_channels:
+       _map_to_converted_state(
+           state,
+           tensor_names.residual_layer,
+           converted_state,
+           f"{converted_state_param_prefix}.residual_layer",
+       )
+     if config.time_embedding_channels is not None:
+       _map_to_converted_state(
+           state,
+           tensor_names.time_embedding,
+           converted_state,
+           f"{converted_state_param_prefix}.time_emb_proj",
+       )
+
+   def _map_attention_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: AttentionBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.AttentionBlock2DConfig,
+   ):
+     if config.normalization_config.type != layers_config.NormalizationType.NONE:
+       _map_to_converted_state(
+           state,
+           tensor_names.norm,
+           converted_state,
+           f"{converted_state_param_prefix}.norm",
+       )
+     attention_layer_prefix = f"{converted_state_param_prefix}.attention"
+     if tensor_names.fused_qkv_proj is not None:
+       _map_to_converted_state(
+           state,
+           tensor_names.fused_qkv_proj,
+           converted_state,
+           f"{attention_layer_prefix}.qkv_projection",
+       )
+     else:
+       _map_to_converted_state(
+           state,
+           tensor_names.q_proj,
+           converted_state,
+           f"{attention_layer_prefix}.q_projection",
+           squeeze_dims=True,
+       )
+       _map_to_converted_state(
+           state,
+           tensor_names.k_proj,
+           converted_state,
+           f"{attention_layer_prefix}.k_projection",
+           squeeze_dims=True,
+       )
+       _map_to_converted_state(
+           state,
+           tensor_names.v_proj,
+           converted_state,
+           f"{attention_layer_prefix}.v_projection",
+           squeeze_dims=True,
+       )
+       converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
+           [
+               converted_state[f"{attention_layer_prefix}.q_projection.weight"],
+               converted_state[f"{attention_layer_prefix}.k_projection.weight"],
+               converted_state[f"{attention_layer_prefix}.v_projection.weight"],
+           ],
+           axis=0,
+       )
+       del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
+       del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
+       del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
+       if config.attention_config.qkv_use_bias:
+         converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
+             [
+                 converted_state[f"{attention_layer_prefix}.q_projection.bias"],
+                 converted_state[f"{attention_layer_prefix}.k_projection.bias"],
+                 converted_state[f"{attention_layer_prefix}.v_projection.bias"],
+             ],
+             axis=0,
+         )
+         del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
+         del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
+         del converted_state[f"{attention_layer_prefix}.v_projection.bias"]
+
+     _map_to_converted_state(
+         state,
+         tensor_names.output_proj,
+         converted_state,
+         f"{attention_layer_prefix}.output_projection",
+         squeeze_dims=True,
+     )
+
+   def _map_cross_attention_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: CrossAttentionBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.CrossAttentionBlock2DConfig,
+   ):
+     if config.normalization_config.type != layers_config.NormalizationType.NONE:
+       _map_to_converted_state(
+           state,
+           tensor_names.norm,
+           converted_state,
+           f"{converted_state_param_prefix}.norm",
+       )
+     attention_layer_prefix = f"{converted_state_param_prefix}.attention"
+     _map_to_converted_state(
+         state,
+         tensor_names.q_proj,
+         converted_state,
+         f"{attention_layer_prefix}.q_projection",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.k_proj,
+         converted_state,
+         f"{attention_layer_prefix}.k_projection",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.v_proj,
+         converted_state,
+         f"{attention_layer_prefix}.v_projection",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.output_proj,
+         converted_state,
+         f"{attention_layer_prefix}.output_projection",
+     )
+
+   def _map_feedforward_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: FeedForwardBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.FeedForwardBlock2DConfig,
+   ):
+     _map_to_converted_state(
+         state,
+         tensor_names.norm,
+         converted_state,
+         f"{converted_state_param_prefix}.norm",
+     )
+     if config.activation_config.type == layers_config.ActivationType.GE_GLU:
+       _map_to_converted_state(
+           state,
+           tensor_names.ge_glu,
+           converted_state,
+           f"{converted_state_param_prefix}.act.proj",
+       )
+     else:
+       _map_to_converted_state(
+           state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+       )
+
+     _map_to_converted_state(
+         state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+     )
+
+   def _map_transformer_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: TransformerBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.TransformerBlock2DConfig,
+   ):
+     _map_to_converted_state(
+         state,
+         tensor_names.pre_conv_norm,
+         converted_state,
+         f"{converted_state_param_prefix}.pre_conv_norm",
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.conv_in,
+         converted_state,
+         f"{converted_state_param_prefix}.conv_in",
+     )
+     self._map_attention_block(
+         state,
+         converted_state,
+         tensor_names.self_attention,
+         f"{converted_state_param_prefix}.self_attention",
+         config.attention_block_config,
+     )
+     self._map_cross_attention_block(
+         state,
+         converted_state,
+         tensor_names.cross_attention,
+         f"{converted_state_param_prefix}.cross_attention",
+         config.cross_attention_block_config,
+     )
+     self._map_feedforward_block(
+         state,
+         converted_state,
+         tensor_names.feed_forward,
+         f"{converted_state_param_prefix}.feed_forward",
+         config.feed_forward_block_config,
+     )
+     _map_to_converted_state(
+         state,
+         tensor_names.conv_out,
+         converted_state,
+         f"{converted_state_param_prefix}.conv_out",
+     )
+
+   def _map_mid_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       tensor_names: MidBlockTensorNames,
+       converted_state_param_prefix: str,
+       config: unet_config.MidBlock2DConfig,
+   ):
+     residual_block_config = unet_config.ResidualBlock2DConfig(
+         in_channels=config.in_channels,
+         out_channels=config.in_channels,
+         time_embedding_channels=config.time_embedding_channels,
+         normalization_config=config.normalization_config,
+         activation_config=config.activation_config,
+     )
+     self._map_residual_block(
+         state,
+         converted_state,
+         tensor_names.residual_block_tensor_names[0],
+         f"{converted_state_param_prefix}.resnets.0",
+         residual_block_config,
+     )
+     for i in range(config.num_layers):
+       if config.attention_block_config:
+         self._map_attention_block(
+             state,
+             converted_state,
+             tensor_names.attention_block_tensor_names[i],
+             f"{converted_state_param_prefix}.attentions.{i}",
+             config.attention_block_config,
+         )
+       if config.transformer_block_config:
+         self._map_transformer_block(
+             state,
+             converted_state,
+             tensor_names.transformer_block_tensor_names[i],
+             f"{converted_state_param_prefix}.transformers.{i}",
+             config.transformer_block_config,
+         )
+       self._map_residual_block(
+           state,
+           converted_state,
+           tensor_names.residual_block_tensor_names[i + 1],
+           f"{converted_state_param_prefix}.resnets.{i+1}",
+           residual_block_config,
+       )
+
+   def _map_down_encoder_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       converted_state_param_prefix: str,
+       config: unet_config.DownEncoderBlock2DConfig,
+       tensor_names: DownEncoderBlockTensorNames,
+   ):
+     for i in range(config.num_layers):
+       input_channels = config.in_channels if i == 0 else config.out_channels
+       self._map_residual_block(
+           state,
+           converted_state,
+           tensor_names.residual_block_tensor_names[i],
+           f"{converted_state_param_prefix}.resnets.{i}",
+           unet_config.ResidualBlock2DConfig(
+               in_channels=input_channels,
+               out_channels=config.out_channels,
+               time_embedding_channels=config.time_embedding_channels,
+               normalization_config=config.normalization_config,
+               activation_config=config.activation_config,
+           ),
+       )
+       if config.transformer_block_config:
+         self._map_transformer_block(
+             state,
+             converted_state,
+             tensor_names.transformer_block_tensor_names[i],
+             f"{converted_state_param_prefix}.transformers.{i}",
+             config.transformer_block_config,
+         )
+     if (
+         config.add_downsample
+         and config.sampling_config.mode == unet_config.SamplingType.CONVOLUTION
+     ):
+       _map_to_converted_state(
+           state,
+           tensor_names.downsample_conv,
+           converted_state,
+           f"{converted_state_param_prefix}.downsampler",
+       )
+
+   def _map_up_decoder_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       converted_state_param_prefix: str,
+       config: unet_config.UpDecoderBlock2DConfig,
+       tensor_names: UpDecoderBlockTensorNames,
+   ):
+     for i in range(config.num_layers):
+       input_channels = config.in_channels if i == 0 else config.out_channels
+       self._map_residual_block(
+           state,
+           converted_state,
+           tensor_names.residual_block_tensor_names[i],
+           f"{converted_state_param_prefix}.resnets.{i}",
+           unet_config.ResidualBlock2DConfig(
+               in_channels=input_channels,
+               out_channels=config.out_channels,
+               time_embedding_channels=config.time_embedding_channels,
+               normalization_config=config.normalization_config,
+               activation_config=config.activation_config,
+           ),
+       )
+       if config.transformer_block_config:
+         self._map_transformer_block(
+             state,
+             converted_state,
+             tensor_names.transformer_block_tensor_names[i],
+             f"{converted_state_param_prefix}.transformers.{i}",
+             config.transformer_block_config,
+         )
+     if config.add_upsample and config.upsample_conv:
+       _map_to_converted_state(
+           state,
+           tensor_names.upsample_conv,
+           converted_state,
+           f"{converted_state_param_prefix}.upsample_conv",
+       )
+
+   def _map_skip_up_decoder_block(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       converted_state_param_prefix: str,
+       config: unet_config.SkipUpDecoderBlock2DConfig,
+       tensor_names: UpDecoderBlockTensorNames,
+   ):
+     for i in range(config.num_layers):
+       res_skip_channels = (
+           config.in_channels if (i == config.num_layers - 1) else config.out_channels
+       )
+       resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+       self._map_residual_block(
+           state,
+           converted_state,
+           tensor_names.residual_block_tensor_names[i],
+           f"{converted_state_param_prefix}.resnets.{i}",
+           unet_config.ResidualBlock2DConfig(
+               in_channels=resnet_in_channels + res_skip_channels,
+               out_channels=config.out_channels,
+               time_embedding_channels=config.time_embedding_channels,
+               normalization_config=config.normalization_config,
+               activation_config=config.activation_config,
+           ),
+       )
+       if config.transformer_block_config:
+         self._map_transformer_block(
+             state,
+             converted_state,
+             tensor_names.transformer_block_tensor_names[i],
+             f"{converted_state_param_prefix}.transformers.{i}",
+             config.transformer_block_config,
+         )
+     if config.add_upsample and config.upsample_conv:
+       _map_to_converted_state(
+           state,
+           tensor_names.upsample_conv,
+           converted_state,
+           f"{converted_state_param_prefix}.upsample_conv",
+       )
+
+
+ # Alias class name for better code reading.
+ ClipModelLoader = BaseLoader
+
+
+ class AutoEncoderModelLoader(BaseLoader):
+
+   @dataclass
+   class TensorNames:
+     quant_conv: str = None
+     post_quant_conv: str = None
+     conv_in: str = None
+     conv_out: str = None
+     final_norm: str = None
+     mid_block_tensor_names: MidBlockTensorNames = None
+     up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
+
+   def __init__(self, file_name: str, names: TensorNames):
+     """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
+
+     Args:
+       file_name (str): Path to the checkpoint. Can be a directory or an
+         exact file.
+       names (TensorNames): An instance of `TensorNames` to determine mappings.
+     """
+     self._file_name = file_name
+     self._names = names
+     self._loader = self._get_loader()
+
+   def load(
+       self, model: torch.nn.Module, strict: bool = True
+   ) -> Tuple[List[str], List[str]]:
+     """Load the model from the checkpoint.
+
+     Args:
+       model (torch.nn.Module): The PyTorch model that needs to be loaded.
+       strict (bool, optional): Whether the converted keys are strictly
+         matched. Defaults to True.
+
+     Returns:
+       missing_keys (List[str]): a list of str containing the missing keys.
+       unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+     Raises:
+       ValueError: If conversion results in unmapped tensors and strict mode is
+         enabled.
+     """
+     state = self._loader(self._file_name)
+     converted_state = dict()
+     if self._names.quant_conv is not None:
+       _map_to_converted_state(
+           state, self._names.quant_conv, converted_state, "quant_conv"
+       )
+     if self._names.post_quant_conv is not None:
+       _map_to_converted_state(
+           state, self._names.post_quant_conv, converted_state, "post_quant_conv"
+       )
+     if self._names.conv_in is not None:
+       _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+     if self._names.conv_out is not None:
+       _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+     if self._names.final_norm is not None:
+       _map_to_converted_state(
+           state, self._names.final_norm, converted_state, "final_norm"
+       )
+     self._map_mid_block(
+         state,
+         converted_state,
+         self._names.mid_block_tensor_names,
+         "mid_block",
+         model.config.mid_block_config,
+     )
+
+     reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+     block_out_channels = reversed_block_out_channels[0]
+     for i, out_channels in enumerate(reversed_block_out_channels):
+       prev_output_channel = block_out_channels
+       block_out_channels = out_channels
+       not_final_block = i < len(reversed_block_out_channels) - 1
+       self._map_up_decoder_block(
+           state,
+           converted_state,
+           f"up_decoder_blocks.{i}",
+           unet_config.UpDecoderBlock2DConfig(
+               in_channels=prev_output_channel,
+               out_channels=block_out_channels,
+               normalization_config=model.config.normalization_config,
+               activation_config=model.config.activation_config,
+               num_layers=model.config.layers_per_block,
+               add_upsample=not_final_block,
+               upsample_conv=True,
+           ),
+           self._names.up_decoder_blocks_tensor_names[i],
+       )
+     if strict and state:
+       raise ValueError(
+           f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
+       )
+     return model.load_state_dict(converted_state, strict=strict)
+
+
+ class DiffusionModelLoader(BaseLoader):
+
+   @dataclass
+   class TensorNames:
+     time_embedding: TimeEmbeddingTensorNames = None
+     conv_in: str = None
+     conv_out: str = None
+     final_norm: str = None
+     down_encoder_blocks_tensor_names: List[DownEncoderBlockTensorNames] = None
+     mid_block_tensor_names: MidBlockTensorNames = None
+     up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
+
+   def __init__(self, file_name: str, names: TensorNames):
+     """DiffusionModelLoader constructor. Can be used to load the diffusion model of Stable Diffusion.
+
+     Args:
+       file_name (str): Path to the checkpoint. Can be a directory or an
+         exact file.
+       names (TensorNames): An instance of `TensorNames` to determine mappings.
+     """
+     self._file_name = file_name
+     self._names = names
+     self._loader = self._get_loader()
+
+   def load(
+       self, model: torch.nn.Module, strict: bool = True
+   ) -> Tuple[List[str], List[str]]:
+     """Load the model from the checkpoint.
+
+     Args:
+       model (torch.nn.Module): The PyTorch model that needs to be loaded.
+       strict (bool, optional): Whether the converted keys are strictly
+         matched. Defaults to True.
+
+     Returns:
+       missing_keys (List[str]): a list of str containing the missing keys.
+       unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+     Raises:
+       ValueError: If conversion results in unmapped tensors and strict mode is
+         enabled.
+     """
+     state = self._loader(self._file_name)
+     converted_state = dict()
+     config: unet_config.DiffusionModelConfig = model.config
+     self._map_time_embedding(
+         state, converted_state, "time_embedding", self._names.time_embedding
+     )
+     _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+     _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+     _map_to_converted_state(
+         state, self._names.final_norm, converted_state, "final_norm"
+     )
+
+     attention_config = layers_config.AttentionConfig(
+         num_heads=config.transformer_num_attention_heads,
+         num_query_groups=config.transformer_num_attention_heads,
+         rotary_percentage=0.0,
+         qkv_transpose_before_split=True,
+         qkv_use_bias=False,
+         output_proj_use_bias=True,
+         enable_kv_cache=False,
+     )
+
+     # Map down_encoders.
+     output_channel = config.block_out_channels[0]
+     for i, block_out_channel in enumerate(config.block_out_channels):
+       input_channel = output_channel
+       output_channel = block_out_channel
+       not_final_block = i < len(config.block_out_channels) - 1
+       if not_final_block:
+         down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
+             in_channels=input_channel,
+             out_channels=output_channel,
+             normalization_config=config.residual_norm_config,
+             activation_config=layers_config.ActivationConfig(
+                 config.residual_activation_type
+             ),
+             num_layers=config.layers_per_block,
+             padding=config.downsample_padding,
+             time_embedding_channels=config.time_embedding_blocks_dim,
+             add_downsample=True,
+             sampling_config=unet_config.DownSamplingConfig(
+                 mode=unet_config.SamplingType.CONVOLUTION,
+                 in_channels=output_channel,
+                 out_channels=output_channel,
+                 kernel_size=3,
+                 stride=2,
+                 padding=config.downsample_padding,
+             ),
+             transformer_block_config=unet_config.TransformerBlock2DConfig(
+                 attention_block_config=unet_config.AttentionBlock2DConfig(
+                     dim=output_channel,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                     query_dim=output_channel,
+                     cross_dim=config.transformer_cross_attention_dim,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                 feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                     dim=output_channel,
+                     hidden_dim=output_channel * 4,
+                     normalization_config=config.transformer_norm_config,
+                     activation_config=layers_config.ActivationConfig(
+                         type=config.transformer_ff_activation_type,
+                         dim_in=output_channel,
+                         dim_out=output_channel * 4,
+                     ),
+                     use_bias=True,
+                 ),
+             ),
+         )
+       else:
+         down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
+             in_channels=input_channel,
+             out_channels=output_channel,
+             normalization_config=config.residual_norm_config,
+             activation_config=layers_config.ActivationConfig(
+                 config.residual_activation_type
+             ),
+             num_layers=config.layers_per_block,
+             padding=config.downsample_padding,
+             time_embedding_channels=config.time_embedding_blocks_dim,
+             add_downsample=False,
+         )
+
+       self._map_down_encoder_block(
+           state,
+           converted_state,
+           f"down_encoders.{i}",
+           down_encoder_block_config,
+           self._names.down_encoder_blocks_tensor_names[i],
+       )
+
+     # Map mid block.
+     mid_block_channels = config.block_out_channels[-1]
+     mid_block_config = unet_config.MidBlock2DConfig(
+         in_channels=mid_block_channels,
+         normalization_config=config.residual_norm_config,
+         activation_config=layers_config.ActivationConfig(
+             config.residual_activation_type
+         ),
+         num_layers=config.mid_block_layers,
+         time_embedding_channels=config.time_embedding_blocks_dim,
+         transformer_block_config=unet_config.TransformerBlock2DConfig(
+             attention_block_config=unet_config.AttentionBlock2DConfig(
+                 dim=mid_block_channels,
+                 normalization_config=config.transformer_norm_config,
+                 attention_config=attention_config,
+             ),
+             cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                 query_dim=mid_block_channels,
+                 cross_dim=config.transformer_cross_attention_dim,
+                 normalization_config=config.transformer_norm_config,
+                 attention_config=attention_config,
+             ),
+             pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+             feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                 dim=mid_block_channels,
+                 hidden_dim=mid_block_channels * 4,
+                 normalization_config=config.transformer_norm_config,
+                 activation_config=layers_config.ActivationConfig(
+                     type=config.transformer_ff_activation_type,
+                     dim_in=mid_block_channels,
+                     dim_out=mid_block_channels * 4,
+                 ),
+                 use_bias=True,
+             ),
+         ),
+     )
+     self._map_mid_block(
+         state,
+         converted_state,
+         self._names.mid_block_tensor_names,
+         "mid_block",
+         mid_block_config,
+     )
+
+     # Map up_decoders.
+     reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+     up_decoder_layers_per_block = config.layers_per_block + 1
+     output_channel = reversed_block_out_channels[0]
+     for i, block_out_channel in enumerate(reversed_block_out_channels):
+       prev_out_channel = output_channel
+       output_channel = block_out_channel
+       input_channel = reversed_block_out_channels[
+           min(i + 1, len(reversed_block_out_channels) - 1)
+       ]
+       not_final_block = i < len(reversed_block_out_channels) - 1
+       not_first_block = i != 0
+       if not_first_block:
+         up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
+             in_channels=input_channel,
+             out_channels=output_channel,
+             prev_out_channels=prev_out_channel,
+             normalization_config=config.residual_norm_config,
+             activation_config=layers_config.ActivationConfig(
+                 config.residual_activation_type
+             ),
+             num_layers=up_decoder_layers_per_block,
+             time_embedding_channels=config.time_embedding_blocks_dim,
+             add_upsample=not_final_block,
+             upsample_conv=True,
+             sampling_config=unet_config.UpSamplingConfig(
+                 mode=unet_config.SamplingType.NEAREST,
+                 scale_factor=2,
+             ),
+             transformer_block_config=unet_config.TransformerBlock2DConfig(
+                 attention_block_config=unet_config.AttentionBlock2DConfig(
+                     dim=output_channel,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                     query_dim=output_channel,
+                     cross_dim=config.transformer_cross_attention_dim,
+                     normalization_config=config.transformer_norm_config,
+                     attention_config=attention_config,
+                 ),
+                 pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                 feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                     dim=output_channel,
+                     hidden_dim=output_channel * 4,
+                     normalization_config=config.transformer_norm_config,
+                     activation_config=layers_config.ActivationConfig(
+                         type=config.transformer_ff_activation_type,
+                         dim_in=output_channel,
+                         dim_out=output_channel * 4,
+                     ),
+                     use_bias=True,
+                 ),
+             ),
+         )
+       else:
+         up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
+             in_channels=input_channel,
+             out_channels=output_channel,
+             prev_out_channels=prev_out_channel,
+             normalization_config=config.residual_norm_config,
+             activation_config=layers_config.ActivationConfig(
+                 config.residual_activation_type
+             ),
+             num_layers=up_decoder_layers_per_block,
+             time_embedding_channels=config.time_embedding_blocks_dim,
+             add_upsample=not_final_block,
+             upsample_conv=True,
+             sampling_config=unet_config.UpSamplingConfig(
+                 mode=unet_config.SamplingType.NEAREST, scale_factor=2
+             ),
+         )
+       self._map_skip_up_decoder_block(
+           state,
+           converted_state,
+           f"up_decoders.{i}",
+           up_encoder_block_config,
+           self._names.up_decoder_blocks_tensor_names[i],
+       )
+     if strict and state:
+       raise ValueError(
+           f"Failed to map all tensors. Remaining tensors are: {list(state.keys())}"
+       )
+     return model.load_state_dict(converted_state, strict=strict)
+
+   def _map_time_embedding(
+       self,
+       state: Dict[str, torch.Tensor],
+       converted_state: Dict[str, torch.Tensor],
+       converted_state_param_prefix: str,
+       tensor_names: TimeEmbeddingTensorNames,
+   ):
+     _map_to_converted_state(
+         state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+     )
+     _map_to_converted_state(
+         state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+     )
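
The loaders above are driven entirely by the TensorNames dataclasses. A minimal sketch of building such a mapping and constructing a loader; every checkpoint key string and the checkpoint path below are illustrative placeholders, not values from this diff (the real mappings ship with the generative/examples/stable_diffusion modules listed in the file table):

import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader

# Hypothetical checkpoint keys for a single mid-block residual layer.
mid_block_names = stable_diffusion_loader.MidBlockTensorNames(
    residual_block_tensor_names=[
        stable_diffusion_loader.ResidualBlockTensorNames(
            norm_1="decoder.mid.block_1.norm1",
            conv_1="decoder.mid.block_1.conv1",
            norm_2="decoder.mid.block_1.norm2",
            conv_2="decoder.mid.block_1.conv2",
        )
    ],
)
names = stable_diffusion_loader.AutoEncoderModelLoader.TensorNames(
    post_quant_conv="post_quant_conv",
    conv_in="decoder.conv_in",
    conv_out="decoder.conv_out",
    final_norm="decoder.norm_out",
    mid_block_tensor_names=mid_block_names,
    up_decoder_blocks_tensor_names=[],  # one UpDecoderBlockTensorNames per decoder block in practice
)

# The constructor picks a checkpoint reader from the file name; loader.load(model)
# then pops the checkpoint tensors into the model's state dict and returns
# (missing_keys, unexpected_keys), as described in the docstrings above.
loader = stable_diffusion_loader.AutoEncoderModelLoader("path/to/vae_checkpoint.pt", names)

DiffusionModelLoader follows the same pattern, with TimeEmbeddingTensorNames, DownEncoderBlockTensorNames, and up-decoder entries filled in on its own TensorNames.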