ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.

Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/test/test_quantize.py

@@ -15,9 +15,6 @@
 
 import unittest
 
-from parameterized import parameterized
-import torch
-
 import ai_edge_torch
 from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
@@ -29,20 +26,20 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
+from parameterized import parameterized
+import torch
 
 
 class TestVerifyRecipes(unittest.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.expand(
-      [
-          (Dtype.FP32, Dtype.FP32),
-          (Dtype.INT8, Dtype.INT8),
-          (Dtype.INT8, Dtype.FP16),
-          (Dtype.FP16, Dtype.INT8),
-          (Dtype.FP16, Dtype.FP16),
-      ]
-  )
+  @parameterized.expand([
+      (Dtype.FP32, Dtype.FP32),
+      (Dtype.INT8, Dtype.INT8),
+      (Dtype.INT8, Dtype.FP16),
+      (Dtype.FP16, Dtype.INT8),
+      (Dtype.FP16, Dtype.FP16),
+  ])
   def test_verify_invalid_recipes(
       self,
       activation,
@@ -54,31 +51,29 @@ class TestVerifyRecipes(unittest.TestCase):
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.expand(
-      [
-          (
-              Dtype.FP32,
-              Dtype.INT8,
-              Mode.DYNAMIC_RANGE,
-              Algorithm.MIN_MAX,
-              Granularity.CHANNELWISE,
-          ),
-          (
-              Dtype.FP32,
-              Dtype.INT8,
-              Mode.WEIGHT_ONLY,
-              Algorithm.MIN_MAX,
-              Granularity.CHANNELWISE,
-          ),
-          (
-              Dtype.FP32,
-              Dtype.FP16,
-              Mode.WEIGHT_ONLY,
-              Algorithm.FLOAT_CAST,
-              Granularity.NONE,
-          ),
-      ]
-  )
+  @parameterized.expand([
+      (
+          Dtype.FP32,
+          Dtype.INT8,
+          Mode.DYNAMIC_RANGE,
+          Algorithm.MIN_MAX,
+          Granularity.CHANNELWISE,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.INT8,
+          Mode.WEIGHT_ONLY,
+          Algorithm.MIN_MAX,
+          Granularity.CHANNELWISE,
+      ),
+      (
+          Dtype.FP32,
+          Dtype.FP16,
+          Mode.WEIGHT_ONLY,
+          Algorithm.FLOAT_CAST,
+          Granularity.NONE,
+      ),
+  ])
   def test_verify_valid_recipes(
       self,
       activation,
@@ -87,7 +82,9 @@ class TestVerifyRecipes(unittest.TestCase):
       algo,
       granularity,
   ):
-    quant_recipe.LayerQuantRecipe(activation, weight, mode, algo, granularity).verify()
+    quant_recipe.LayerQuantRecipe(
+        activation, weight, mode, algo, granularity
+    ).verify()
 
 
 class TestQuantizeConvert(unittest.TestCase):
@@ -107,15 +104,13 @@ class TestQuantizeConvert(unittest.TestCase):
         )
     )
 
-  @parameterized.expand(
-      [
-          (quant_recipes.full_fp16_recipe()),
-          (quant_recipes.full_int8_dynamic_recipe()),
-          (quant_recipes.full_int8_weight_only_recipe()),
-          (_attention_int8_dynamic_recipe()),
-          (_feedforward_int8_dynamic_recipe()),
-      ]
-  )
+  @parameterized.expand([
+      (quant_recipes.full_fp16_recipe()),
+      (quant_recipes.full_int8_dynamic_recipe()),
+      (quant_recipes.full_int8_weight_only_recipe()),
+      (_attention_int8_dynamic_recipe()),
+      (_feedforward_int8_dynamic_recipe()),
+  ])
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -146,7 +141,9 @@ class TestQuantizeConvert(unittest.TestCase):
     )
     float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
 
-    self.assertLess(len(quantized_model._tflite_model), len(float_model._tflite_model))
+    self.assertLess(
+        len(quantized_model._tflite_model), len(float_model._tflite_model)
+    )
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             quantized_model,
ai_edge_torch/generative/utilities/loader.py

@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict, List, Tuple
 
+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch
 
-from ai_edge_torch.generative.layers import model_config
-
 
 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
@@ -158,14 +157,22 @@ class ModelLoader:
           f"{self._names.embedding_position}"
       )
     if self._names.lm_head is not None:
-      converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
+      converted_state["lm_head.weight"] = state.pop(
+          f"{self._names.lm_head}.weight"
+      )
       if model.config.lm_head_use_bias:
-        converted_state["lm_head.bias"] = state.pop(f"{self._names.lm_head}.bias")
+        converted_state["lm_head.bias"] = state.pop(
+            f"{self._names.lm_head}.bias"
+        )
     if self._names.final_norm is not None:
       final_norm_name = self._names.final_norm
-      converted_state["final_norm.weight"] = state.pop(f"{final_norm_name}.weight")
+      converted_state["final_norm.weight"] = state.pop(
+          f"{final_norm_name}.weight"
+      )
       if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )
 
     for i in range(model.config.num_layers):
       self._map_norm(i, model.config, state, converted_state)
@@ -214,18 +221,26 @@ class ModelLoader:
     if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
       if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
     else:
       ff_up_proj_name = self._names.ff_up_proj.format(idx)
       ff_down_proj_name = self._names.ff_down_proj.format(idx)
       ff_gate_proj_name = self._names.ff_gate_proj.format(idx)
-      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w3.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
@@ -233,9 +248,15 @@ class ModelLoader:
           f"{ff_gate_proj_name}.weight"
       )
       if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_gate_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_gate_proj_name}.bias"
+        )
 
   def _map_attention(
       self,
@@ -254,11 +275,13 @@ class ModelLoader:
       q_name = self._names.attn_query_proj.format(idx)
       k_name = self._names.attn_key_proj.format(idx)
       v_name = self._names.attn_value_proj.format(idx)
-      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
+          self._fuse_qkv(
+              config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
       )
     if config.attn_config.qkv_use_bias:
       if self._names.attn_fused_qkv_proj:
@@ -266,20 +289,22 @@
            f"{fused_qkv_name}.bias"
        )
      else:
-        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
+            self._fuse_qkv(
+                config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
        )
 
    o_name = self._names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
    )
    if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )
 
  def _map_norm(
@@ -318,7 +343,9 @@ class ModelLoader:
       v: torch.Tensor,
   ) -> torch.Tensor:
     if config.attn_config.qkv_fused_interleaved:
-      q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+      q_per_kv = (
+          config.attn_config.num_heads // config.attn_config.num_query_groups
+      )
       qs = torch.split(q, config.head_dim * q_per_kv)
       ks = torch.split(k, config.head_dim)
       vs = torch.split(v, config.head_dim)
ai_edge_torch/generative/utilities/stable_diffusion_loader.py

@@ -16,11 +16,10 @@
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
-import torch
-
 import ai_edge_torch.generative.layers.model_config as layers_config
 import ai_edge_torch.generative.layers.unet.model_config as unet_config
 import ai_edge_torch.generative.utilities.loader as loader
+import torch
 
 
 @dataclass
@@ -80,27 +79,35 @@ class TransformerBlockTensorNames:
 class MidBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
   attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
 
 
 @dataclass
 class DownEncoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   downsample_conv: str = None
 
 
 @dataclass
 class UpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
 @dataclass
 class SkipUpDecoderBlockTensorNames:
   residual_block_tensor_names: List[ResidualBlockTensorNames]
-  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[
+      List[TransformerBlockTensorNames]
+  ] = None
   upsample_conv: str = None
 
 
@@ -119,7 +126,9 @@ def _map_to_converted_state(
         converted_state[f"{converted_state_param}.weight"]
     )
   if f"{state_param}.bias" in state:
-    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+    converted_state[f"{converted_state_param}.bias"] = state.pop(
+        f"{state_param}.bias"
+    )
     if squeeze_dims:
       converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
           converted_state[f"{converted_state_param}.bias"]
@@ -220,25 +229,41 @@ class BaseLoader(loader.ModelLoader):
         f"{attention_layer_prefix}.v_projection",
         squeeze_dims=True,
     )
-    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
-        [
-            converted_state[f"{attention_layer_prefix}.q_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.k_projection.weight"],
-            converted_state[f"{attention_layer_prefix}.v_projection.weight"],
-        ],
-        axis=0,
+    converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = (
+        torch.concat(
+            [
+                converted_state[
+                    f"{attention_layer_prefix}.q_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.k_projection.weight"
+                ],
+                converted_state[
+                    f"{attention_layer_prefix}.v_projection.weight"
+                ],
+            ],
+            axis=0,
+        )
     )
     del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
    del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
    del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
    if config.attention_config.qkv_use_bias:
-      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
-          [
-              converted_state[f"{attention_layer_prefix}.q_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.k_projection.bias"],
-              converted_state[f"{attention_layer_prefix}.v_projection.bias"],
-          ],
-          axis=0,
+      converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = (
+          torch.concat(
+              [
+                  converted_state[
+                      f"{attention_layer_prefix}.q_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.k_projection.bias"
+                  ],
+                  converted_state[
+                      f"{attention_layer_prefix}.v_projection.bias"
+                  ],
+              ],
+              axis=0,
+          )
      )
    del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
    del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
@@ -316,11 +341,17 @@ class BaseLoader(loader.ModelLoader):
       )
     else:
       _map_to_converted_state(
-          state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+          state,
+          tensor_names.w1,
+          converted_state,
+          f"{converted_state_param_prefix}.w1",
       )
 
     _map_to_converted_state(
-        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+        state,
+        tensor_names.w2,
+        converted_state,
+        f"{converted_state_param_prefix}.w2",
    )
 
   def _map_transformer_block(
@@ -509,9 +540,13 @@ class BaseLoader(loader.ModelLoader):
   ):
     for i in range(config.num_layers):
       res_skip_channels = (
-          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+          config.in_channels
+          if (i == config.num_layers - 1)
+          else config.out_channels
+      )
+      resnet_in_channels = (
+          config.prev_out_channels if i == 0 else config.out_channels
       )
-      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
       self._map_residual_block(
           state,
           converted_state,
@@ -599,9 +634,13 @@ class AutoEncoderModelLoader(BaseLoader):
         state, self._names.post_quant_conv, converted_state, "post_quant_conv"
     )
     if self._names.conv_in is not None:
-      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+      _map_to_converted_state(
+          state, self._names.conv_in, converted_state, "conv_in"
+      )
     if self._names.conv_out is not None:
-      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+      _map_to_converted_state(
+          state, self._names.conv_out, converted_state, "conv_out"
+      )
     if self._names.final_norm is not None:
       _map_to_converted_state(
           state, self._names.final_norm, converted_state, "final_norm"
@@ -614,7 +653,9 @@ class AutoEncoderModelLoader(BaseLoader):
         model.config.mid_block_config,
     )
 
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     block_out_channels = reversed_block_out_channels[0]
     for i, out_channels in enumerate(reversed_block_out_channels):
       prev_output_channel = block_out_channels
@@ -690,8 +731,12 @@ class DiffusionModelLoader(BaseLoader):
     self._map_time_embedding(
         state, converted_state, "time_embedding", self._names.time_embedding
     )
-    _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
-    _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+    _map_to_converted_state(
+        state, self._names.conv_in, converted_state, "conv_in"
+    )
+    _map_to_converted_state(
+        state, self._names.conv_out, converted_state, "conv_out"
+    )
     _map_to_converted_state(
         state, self._names.final_norm, converted_state, "final_norm"
     )
@@ -825,7 +870,9 @@ class DiffusionModelLoader(BaseLoader):
     )
 
     # Map up_decoders.
-    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    reversed_block_out_channels = list(
+        reversed(model.config.block_out_channels)
+    )
     up_decoder_layers_per_block = config.layers_per_block + 1
     output_channel = reversed_block_out_channels[0]
     for i, block_out_channel in enumerate(reversed_block_out_channels):
@@ -917,8 +964,14 @@ class DiffusionModelLoader(BaseLoader):
       tensor_names: TimeEmbeddingTensorNames,
   ):
     _map_to_converted_state(
-        state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+        state,
+        tensor_names.w1,
+        converted_state,
+        f"{converted_state_param_prefix}.w1",
    )
     _map_to_converted_state(
-        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+        state,
+        tensor_names.w2,
+        converted_state,
+        f"{converted_state_param_prefix}.w2",
    )