ai-edge-torch-nightly 0.2.0.dev20240611__py3-none-any.whl → 0.2.0.dev20240617__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic.

Files changed (21):
  1. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +19 -0
  2. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +9 -2
  3. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +9 -6
  4. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +33 -25
  5. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +523 -202
  6. ai_edge_torch/generative/examples/t5/t5_attention.py +10 -39
  7. ai_edge_torch/generative/layers/attention.py +154 -26
  8. ai_edge_torch/generative/layers/model_config.py +3 -0
  9. ai_edge_torch/generative/layers/unet/blocks_2d.py +473 -49
  10. ai_edge_torch/generative/layers/unet/builder.py +20 -2
  11. ai_edge_torch/generative/layers/unet/model_config.py +157 -5
  12. ai_edge_torch/generative/test/test_model_conversion.py +24 -0
  13. ai_edge_torch/generative/test/test_quantize.py +1 -0
  14. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +860 -0
  15. ai_edge_torch/generative/utilities/t5_loader.py +33 -17
  16. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/METADATA +1 -1
  17. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/RECORD +20 -20
  18. ai_edge_torch/generative/utilities/autoencoder_loader.py +0 -298
  19. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/LICENSE +0 -0
  20. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/WHEEL +0 -0
  21. {ai_edge_torch_nightly-0.2.0.dev20240611.dist-info → ai_edge_torch_nightly-0.2.0.dev20240617.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/stable_diffusion_loader.py (new file)
@@ -0,0 +1,860 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Common utility functions for data loading etc.
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+import ai_edge_torch.generative.layers.model_config as layers_config
+import ai_edge_torch.generative.layers.unet.model_config as unet_config
+import ai_edge_torch.generative.utilities.loader as loader
+
+
+@dataclass
+class ResidualBlockTensorNames:
+  norm_1: str = None
+  conv_1: str = None
+  norm_2: str = None
+  conv_2: str = None
+  residual_layer: str = None
+  time_embedding: str = None
+
+
+@dataclass
+class AttentionBlockTensorNames:
+  norm: str = None
+  fused_qkv_proj: str = None
+  output_proj: str = None
+
+
+@dataclass
+class CrossAttentionBlockTensorNames:
+  norm: str = None
+  q_proj: str = None
+  k_proj: str = None
+  v_proj: str = None
+  output_proj: str = None
+
+
+@dataclass
+class TimeEmbeddingTensorNames:
+  w1: str = None
+  w2: str = None
+
+
+@dataclass
+class FeedForwardBlockTensorNames:
+  w1: str = None
+  w2: str = None
+  norm: str = None
+  ge_glu: str = None
+
+
+@dataclass
+class TransformerBlockTensorNames:
+  pre_conv_norm: str
+  conv_in: str
+  self_attention: AttentionBlockTensorNames
+  cross_attention: CrossAttentionBlockTensorNames
+  feed_forward: FeedForwardBlockTensorNames
+  conv_out: str
+
+
+@dataclass
+class MidBlockTensorNames:
+  residual_block_tensor_names: List[ResidualBlockTensorNames]
+  attention_block_tensor_names: Optional[List[AttentionBlockTensorNames]] = None
+  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+
+
+@dataclass
+class DownEncoderBlockTensorNames:
+  residual_block_tensor_names: List[ResidualBlockTensorNames]
+  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  downsample_conv: str = None
+
+
+@dataclass
+class UpDecoderBlockTensorNames:
+  residual_block_tensor_names: List[ResidualBlockTensorNames]
+  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  upsample_conv: str = None
+
+
+@dataclass
+class SkipUpDecoderBlockTensorNames:
+  residual_block_tensor_names: List[ResidualBlockTensorNames]
+  transformer_block_tensor_names: Optional[List[TransformerBlockTensorNames]] = None
+  upsample_conv: str = None
+
+
+def _map_to_converted_state(
+    state: Dict[str, torch.Tensor],
+    state_param: str,
+    converted_state: Dict[str, torch.Tensor],
+    converted_state_param: str,
+):
+  converted_state[f"{converted_state_param}.weight"] = state.pop(
+      f"{state_param}.weight"
+  )
+  if f"{state_param}.bias" in state:
+    converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
+
+
+class BaseLoader(loader.ModelLoader):
+
+  def _map_residual_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: ResidualBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.ResidualBlock2DConfig,
+  ):
+    _map_to_converted_state(
+        state,
+        tensor_names.norm_1,
+        converted_state,
+        f"{converted_state_param_prefix}.norm_1",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.conv_1,
+        converted_state,
+        f"{converted_state_param_prefix}.conv_1",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.norm_2,
+        converted_state,
+        f"{converted_state_param_prefix}.norm_2",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.conv_2,
+        converted_state,
+        f"{converted_state_param_prefix}.conv_2",
+    )
+    if config.in_channels != config.out_channels:
+      _map_to_converted_state(
+          state,
+          tensor_names.residual_layer,
+          converted_state,
+          f"{converted_state_param_prefix}.residual_layer",
+      )
+    if config.time_embedding_channels is not None:
+      _map_to_converted_state(
+          state,
+          tensor_names.time_embedding,
+          converted_state,
+          f"{converted_state_param_prefix}.time_emb_proj",
+      )
+
+  def _map_attention_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: AttentionBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.AttentionBlock2DConfig,
+  ):
+    if config.normalization_config.type != layers_config.NormalizationType.NONE:
+      _map_to_converted_state(
+          state,
+          tensor_names.norm,
+          converted_state,
+          f"{converted_state_param_prefix}.norm",
+      )
+    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
+    _map_to_converted_state(
+        state,
+        tensor_names.fused_qkv_proj,
+        converted_state,
+        f"{attention_layer_prefix}.qkv_projection",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.output_proj,
+        converted_state,
+        f"{attention_layer_prefix}.output_projection",
+    )
+
+  def _map_cross_attention_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: CrossAttentionBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.CrossAttentionBlock2DConfig,
+  ):
+    if config.normalization_config.type != layers_config.NormalizationType.NONE:
+      _map_to_converted_state(
+          state,
+          tensor_names.norm,
+          converted_state,
+          f"{converted_state_param_prefix}.norm",
+      )
+    attention_layer_prefix = f"{converted_state_param_prefix}.attention"
+    _map_to_converted_state(
+        state,
+        tensor_names.q_proj,
+        converted_state,
+        f"{attention_layer_prefix}.q_projection",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.k_proj,
+        converted_state,
+        f"{attention_layer_prefix}.k_projection",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.v_proj,
+        converted_state,
+        f"{attention_layer_prefix}.v_projection",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.output_proj,
+        converted_state,
+        f"{attention_layer_prefix}.output_projection",
+    )
+
+  def _map_feedforward_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: FeedForwardBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.FeedForwardBlock2DConfig,
+  ):
+    _map_to_converted_state(
+        state,
+        tensor_names.norm,
+        converted_state,
+        f"{converted_state_param_prefix}.norm",
+    )
+    if config.activation_config.type == layers_config.ActivationType.GE_GLU:
+      _map_to_converted_state(
+          state,
+          tensor_names.ge_glu,
+          converted_state,
+          f"{converted_state_param_prefix}.act.proj",
+      )
+    else:
+      _map_to_converted_state(
+          state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+      )
+
+    _map_to_converted_state(
+        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+    )
+
+  def _map_transformer_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: TransformerBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.TransformerBlock2Dconfig,
+  ):
+    _map_to_converted_state(
+        state,
+        tensor_names.pre_conv_norm,
+        converted_state,
+        f"{converted_state_param_prefix}.pre_conv_norm",
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.conv_in,
+        converted_state,
+        f"{converted_state_param_prefix}.conv_in",
+    )
+    self._map_attention_block(
+        state,
+        converted_state,
+        tensor_names.self_attention,
+        f"{converted_state_param_prefix}.self_attention",
+        config.attention_block_config,
+    )
+    self._map_cross_attention_block(
+        state,
+        converted_state,
+        tensor_names.cross_attention,
+        f"{converted_state_param_prefix}.cross_attention",
+        config.cross_attention_block_config,
+    )
+    self._map_feedforward_block(
+        state,
+        converted_state,
+        tensor_names.feed_forward,
+        f"{converted_state_param_prefix}.feed_forward",
+        config.feed_forward_block_config,
+    )
+    _map_to_converted_state(
+        state,
+        tensor_names.conv_out,
+        converted_state,
+        f"{converted_state_param_prefix}.conv_out",
+    )
+
+  def _map_mid_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      tensor_names: MidBlockTensorNames,
+      converted_state_param_prefix: str,
+      config: unet_config.MidBlock2DConfig,
+  ):
+    residual_block_config = unet_config.ResidualBlock2DConfig(
+        in_channels=config.in_channels,
+        out_channels=config.in_channels,
+        time_embedding_channels=config.time_embedding_channels,
+        normalization_config=config.normalization_config,
+        activation_config=config.activation_config,
+    )
+    self._map_residual_block(
+        state,
+        converted_state,
+        tensor_names.residual_block_tensor_names[0],
+        f"{converted_state_param_prefix}.resnets.0",
+        residual_block_config,
+    )
+    for i in range(config.num_layers):
+      if config.attention_block_config:
+        self._map_attention_block(
+            state,
+            converted_state,
+            tensor_names.attention_block_tensor_names[i],
+            f"{converted_state_param_prefix}.attentions.{i}",
+            config.attention_block_config,
+        )
+      if config.transformer_block_config:
+        self._map_transformer_block(
+            state,
+            converted_state,
+            tensor_names.transformer_block_tensor_names[i],
+            f"{converted_state_param_prefix}.transformers.{i}",
+            config.transformer_block_config,
+        )
+      self._map_residual_block(
+          state,
+          converted_state,
+          tensor_names.residual_block_tensor_names[i + 1],
+          f"{converted_state_param_prefix}.resnets.{i+1}",
+          residual_block_config,
+      )
+
+  def _map_down_encoder_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      converted_state_param_prefix: str,
+      config: unet_config.DownEncoderBlock2DConfig,
+      tensor_names: DownEncoderBlockTensorNames,
+  ):
+    for i in range(config.num_layers):
+      input_channels = config.in_channels if i == 0 else config.out_channels
+      self._map_residual_block(
+          state,
+          converted_state,
+          tensor_names.residual_block_tensor_names[i],
+          f"{converted_state_param_prefix}.resnets.{i}",
+          unet_config.ResidualBlock2DConfig(
+              in_channels=input_channels,
+              out_channels=config.out_channels,
+              time_embedding_channels=config.time_embedding_channels,
+              normalization_config=config.normalization_config,
+              activation_config=config.activation_config,
+          ),
+      )
+      if config.transformer_block_config:
+        self._map_transformer_block(
+            state,
+            converted_state,
+            tensor_names.transformer_block_tensor_names[i],
+            f"{converted_state_param_prefix}.transformers.{i}",
+            config.transformer_block_config,
+        )
+    if (
+        config.add_downsample
+        and config.sampling_config.mode == unet_config.SamplingType.CONVOLUTION
+    ):
+      _map_to_converted_state(
+          state,
+          tensor_names.downsample_conv,
+          converted_state,
+          f"{converted_state_param_prefix}.downsampler",
+      )
+
+  def _map_up_decoder_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      converted_state_param_prefix: str,
+      config: unet_config.UpDecoderBlock2DConfig,
+      tensor_names: UpDecoderBlockTensorNames,
+  ):
+    for i in range(config.num_layers):
+      input_channels = config.in_channels if i == 0 else config.out_channels
+      self._map_residual_block(
+          state,
+          converted_state,
+          tensor_names.residual_block_tensor_names[i],
+          f"{converted_state_param_prefix}.resnets.{i}",
+          unet_config.ResidualBlock2DConfig(
+              in_channels=input_channels,
+              out_channels=config.out_channels,
+              time_embedding_channels=config.time_embedding_channels,
+              normalization_config=config.normalization_config,
+              activation_config=config.activation_config,
+          ),
+      )
+      if config.transformer_block_config:
+        self._map_transformer_block(
+            state,
+            converted_state,
+            tensor_names.transformer_block_tensor_names[i],
+            f"{converted_state_param_prefix}.transformers.{i}",
+            config.transformer_block_config,
+        )
+    if config.add_upsample and config.upsample_conv:
+      _map_to_converted_state(
+          state,
+          tensor_names.upsample_conv,
+          converted_state,
+          f"{converted_state_param_prefix}.upsample_conv",
+      )
+
+  def _map_skip_up_decoder_block(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      converted_state_param_prefix: str,
+      config: unet_config.SkipUpDecoderBlock2DConfig,
+      tensor_names: UpDecoderBlockTensorNames,
+  ):
+    for i in range(config.num_layers):
+      res_skip_channels = (
+          config.in_channels if (i == config.num_layers - 1) else config.out_channels
+      )
+      resnet_in_channels = config.prev_out_channels if i == 0 else config.out_channels
+      self._map_residual_block(
+          state,
+          converted_state,
+          tensor_names.residual_block_tensor_names[i],
+          f"{converted_state_param_prefix}.resnets.{i}",
+          unet_config.ResidualBlock2DConfig(
+              in_channels=resnet_in_channels + res_skip_channels,
+              out_channels=config.out_channels,
+              time_embedding_channels=config.time_embedding_channels,
+              normalization_config=config.normalization_config,
+              activation_config=config.activation_config,
+          ),
+      )
+      if config.transformer_block_config:
+        self._map_transformer_block(
+            state,
+            converted_state,
+            tensor_names.transformer_block_tensor_names[i],
+            f"{converted_state_param_prefix}.transformers.{i}",
+            config.transformer_block_config,
+        )
+    if config.add_upsample and config.upsample_conv:
+      _map_to_converted_state(
+          state,
+          tensor_names.upsample_conv,
+          converted_state,
+          f"{converted_state_param_prefix}.upsample_conv",
+      )
+
+
+class AutoEncoderModelLoader(BaseLoader):
+
+  @dataclass
+  class TensorNames:
+    quant_conv: str = None
+    post_quant_conv: str = None
+    conv_in: str = None
+    conv_out: str = None
+    final_norm: str = None
+    mid_block_tensor_names: MidBlockTensorNames = None
+    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
+
+  def __init__(self, file_name: str, names: TensorNames):
+    """AutoEncoderModelLoader constructor. Can be used to load encoder and decoder models.
+
+    Args:
+      file_name (str): Path to the checkpoint. Can be a directory or an
+        exact file.
+      names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = self._get_loader()
+
+  def load(
+      self, model: torch.nn.Module, strict: bool = True
+  ) -> Tuple[List[str], List[str]]:
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      strict (bool, optional): Whether the converted keys are strictly
+        matched. Defaults to True.
+
+    Returns:
+      missing_keys (List[str]): a list of str containing the missing keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+    Raises:
+      ValueError: If conversion results in unmapped tensors and strict mode is
+        enabled.
+    """
+    state = self._loader(self._file_name)
+    converted_state = dict()
+    if self._names.quant_conv is not None:
+      _map_to_converted_state(
+          state, self._names.quant_conv, converted_state, "quant_conv"
+      )
+    if self._names.post_quant_conv is not None:
+      _map_to_converted_state(
+          state, self._names.post_quant_conv, converted_state, "post_quant_conv"
+      )
+    if self._names.conv_in is not None:
+      _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+    if self._names.conv_out is not None:
+      _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+    if self._names.final_norm is not None:
+      _map_to_converted_state(
+          state, self._names.final_norm, converted_state, "final_norm"
+      )
+    self._map_mid_block(
+        state,
+        converted_state,
+        self._names.mid_block_tensor_names,
+        "mid_block",
+        model.config.mid_block_config,
+    )
+
+    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    block_out_channels = reversed_block_out_channels[0]
+    for i, out_channels in enumerate(reversed_block_out_channels):
+      prev_output_channel = block_out_channels
+      block_out_channels = out_channels
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      self._map_up_decoder_block(
+          state,
+          converted_state,
+          f"up_decoder_blocks.{i}",
+          unet_config.UpDecoderBlock2DConfig(
+              in_channels=prev_output_channel,
+              out_channels=block_out_channels,
+              normalization_config=model.config.normalization_config,
+              activation_config=model.config.activation_config,
+              num_layers=model.config.layers_per_block,
+              add_upsample=not_final_block,
+              upsample_conv=True,
+          ),
+          self._names.up_decoder_blocks_tensor_names[i],
+      )
+    if strict and state:
+      raise ValueError(
+          f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
+      )
+    return model.load_state_dict(converted_state, strict=strict)
+
+
+class DiffusionModelLoader(BaseLoader):
+
+  @dataclass
+  class TensorNames:
+    time_embedding: TimeEmbeddingTensorNames = None
+    conv_in: str = None
+    conv_out: str = None
+    final_norm: str = None
+    down_encoder_blocks_tensor_names: List[DownEncoderBlockTensorNames] = None
+    mid_block_tensor_names: MidBlockTensorNames = None
+    up_decoder_blocks_tensor_names: List[UpDecoderBlockTensorNames] = None
+
+  def __init__(self, file_name: str, names: TensorNames):
+    """DiffusionModelLoader constructor. Can be used to load diffusion models of Stable Diffusion.
+
+    Args:
+      file_name (str): Path to the checkpoint. Can be a directory or an
+        exact file.
+      names (TensorNames): An instance of `TensorNames` to determine mappings.
+    """
+    self._file_name = file_name
+    self._names = names
+    self._loader = self._get_loader()
+
+  def load(
+      self, model: torch.nn.Module, strict: bool = True
+  ) -> Tuple[List[str], List[str]]:
+    """Load the model from the checkpoint.
+
+    Args:
+      model (torch.nn.Module): The pytorch model that needs to be loaded.
+      strict (bool, optional): Whether the converted keys are strictly
+        matched. Defaults to True.
+
+    Returns:
+      missing_keys (List[str]): a list of str containing the missing keys.
+      unexpected_keys (List[str]): a list of str containing the unexpected keys.
+
+    Raises:
+      ValueError: If conversion results in unmapped tensors and strict mode is
+        enabled.
+    """
+    state = self._loader(self._file_name)
+    converted_state = dict()
+    config: unet_config.DiffusionModelConfig = model.config
+    self._map_time_embedding(
+        state, converted_state, "time_embedding", self._names.time_embedding
+    )
+    _map_to_converted_state(state, self._names.conv_in, converted_state, "conv_in")
+    _map_to_converted_state(state, self._names.conv_out, converted_state, "conv_out")
+    _map_to_converted_state(
+        state, self._names.final_norm, converted_state, "final_norm"
+    )
+
+    attention_config = layers_config.AttentionConfig(
+        num_heads=config.transformer_num_attention_heads,
+        num_query_groups=config.transformer_num_attention_heads,
+        rotary_percentage=0.0,
+        qkv_transpose_before_split=True,
+        qkv_use_bias=False,
+        output_proj_use_bias=True,
+        enable_kv_cache=False,
+    )
+
+    # Map down_encoders.
+    output_channel = config.block_out_channels[0]
+    for i, block_out_channel in enumerate(config.block_out_channels):
+      input_channel = output_channel
+      output_channel = block_out_channel
+      not_final_block = i < len(config.block_out_channels) - 1
+      if not_final_block:
+        down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
+            in_channels=input_channel,
+            out_channels=output_channel,
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_config.ActivationConfig(
+                config.residual_activation_type
+            ),
+            num_layers=config.layers_per_block,
+            padding=config.downsample_padding,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            add_downsample=True,
+            sampling_config=unet_config.DownSamplingConfig(
+                mode=unet_config.SamplingType.CONVOLUTION,
+                in_channels=output_channel,
+                out_channels=output_channel,
+                kernel_size=3,
+                stride=2,
+                padding=config.downsample_padding,
+            ),
+            transformer_block_config=unet_config.TransformerBlock2Dconfig(
+                attention_block_config=unet_config.AttentionBlock2DConfig(
+                    dim=output_channel,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                    query_dim=output_channel,
+                    cross_dim=config.transformer_cross_attention_dim,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                    dim=output_channel,
+                    hidden_dim=output_channel * 4,
+                    normalization_config=config.transformer_norm_config,
+                    activation_config=layers_config.ActivationConfig(
+                        type=config.transformer_ff_activation_type,
+                        dim_in=output_channel,
+                        dim_out=output_channel * 4,
+                    ),
+                    use_bias=True,
+                ),
+            ),
+        )
+      else:
+        down_encoder_block_config = unet_config.DownEncoderBlock2DConfig(
+            in_channels=input_channel,
+            out_channels=output_channel,
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_config.ActivationConfig(
+                config.residual_activation_type
+            ),
+            num_layers=config.layers_per_block,
+            padding=config.downsample_padding,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            add_downsample=False,
+        )
+
+      self._map_down_encoder_block(
+          state,
+          converted_state,
+          f"down_encoders.{i}",
+          down_encoder_block_config,
+          self._names.down_encoder_blocks_tensor_names[i],
+      )
+
+    # Map mid block.
+    mid_block_channels = config.block_out_channels[-1]
+    mid_block_config = unet_config.MidBlock2DConfig(
+        in_channels=mid_block_channels,
+        normalization_config=config.residual_norm_config,
+        activation_config=layers_config.ActivationConfig(
+            config.residual_activation_type
+        ),
+        num_layers=config.mid_block_layers,
+        time_embedding_channels=config.time_embedding_blocks_dim,
+        transformer_block_config=unet_config.TransformerBlock2Dconfig(
+            attention_block_config=unet_config.AttentionBlock2DConfig(
+                dim=mid_block_channels,
+                normalization_config=config.transformer_norm_config,
+                attention_config=attention_config,
+            ),
+            cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                query_dim=mid_block_channels,
+                cross_dim=config.transformer_cross_attention_dim,
+                normalization_config=config.transformer_norm_config,
+                attention_config=attention_config,
+            ),
+            pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+            feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                dim=mid_block_channels,
+                hidden_dim=mid_block_channels * 4,
+                normalization_config=config.transformer_norm_config,
+                activation_config=layers_config.ActivationConfig(
+                    type=config.transformer_ff_activation_type,
+                    dim_in=mid_block_channels,
+                    dim_out=mid_block_channels * 4,
+                ),
+                use_bias=True,
+            ),
+        ),
+    )
+    self._map_mid_block(
+        state,
+        converted_state,
+        self._names.mid_block_tensor_names,
+        "mid_block",
+        mid_block_config,
+    )
+
+    # Map up_decoders.
+    reversed_block_out_channels = list(reversed(model.config.block_out_channels))
+    up_decoder_layers_per_block = config.layers_per_block + 1
+    output_channel = reversed_block_out_channels[0]
+    for i, block_out_channel in enumerate(reversed_block_out_channels):
+      prev_out_channel = output_channel
+      output_channel = block_out_channel
+      input_channel = reversed_block_out_channels[
+          min(i + 1, len(reversed_block_out_channels) - 1)
+      ]
+      not_final_block = i < len(reversed_block_out_channels) - 1
+      not_first_block = i != 0
+      if not_first_block:
+        up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
+            in_channels=input_channel,
+            out_channels=output_channel,
+            prev_out_channels=prev_out_channel,
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_config.ActivationConfig(
+                config.residual_activation_type
+            ),
+            num_layers=up_decoder_layers_per_block,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            add_upsample=not_final_block,
+            upsample_conv=True,
+            sampling_config=unet_config.UpSamplingConfig(
+                mode=unet_config.SamplingType.NEAREST,
+                scale_factor=2,
+            ),
+            transformer_block_config=unet_config.TransformerBlock2Dconfig(
+                attention_block_config=unet_config.AttentionBlock2DConfig(
+                    dim=output_channel,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
+                    query_dim=output_channel,
+                    cross_dim=config.transformer_cross_attention_dim,
+                    normalization_config=config.transformer_norm_config,
+                    attention_config=attention_config,
+                ),
+                pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
+                feed_forward_block_config=unet_config.FeedForwardBlock2DConfig(
+                    dim=output_channel,
+                    hidden_dim=output_channel * 4,
+                    normalization_config=config.transformer_norm_config,
+                    activation_config=layers_config.ActivationConfig(
+                        type=config.transformer_ff_activation_type,
+                        dim_in=output_channel,
+                        dim_out=output_channel * 4,
+                    ),
+                    use_bias=True,
+                ),
+            ),
+        )
+      else:
+        up_encoder_block_config = unet_config.SkipUpDecoderBlock2DConfig(
+            in_channels=input_channel,
+            out_channels=output_channel,
+            prev_out_channels=prev_out_channel,
+            normalization_config=config.residual_norm_config,
+            activation_config=layers_config.ActivationConfig(
+                config.residual_activation_type
+            ),
+            num_layers=up_decoder_layers_per_block,
+            time_embedding_channels=config.time_embedding_blocks_dim,
+            add_upsample=not_final_block,
+            upsample_conv=True,
+            sampling_config=unet_config.UpSamplingConfig(
+                mode=unet_config.SamplingType.NEAREST, scale_factor=2
+            ),
+        )
+      self._map_skip_up_decoder_block(
+          state,
+          converted_state,
+          f"up_decoders.{i}",
+          up_encoder_block_config,
+          self._names.up_decoder_blocks_tensor_names[i],
+      )
+    if strict and state:
+      raise ValueError(
+          f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
+      )
+    return model.load_state_dict(converted_state, strict=strict)
+
+  def _map_time_embedding(
+      self,
+      state: Dict[str, torch.Tensor],
+      converted_state: Dict[str, torch.Tensor],
+      converted_state_param_prefix: str,
+      tensor_names: TimeEmbeddingTensorNames,
+  ):
+    _map_to_converted_state(
+        state, tensor_names.w1, converted_state, f"{converted_state_param_prefix}.w1"
+    )
+    _map_to_converted_state(
+        state, tensor_names.w2, converted_state, f"{converted_state_param_prefix}.w2"
+    )
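
To make the intent of the new loader utilities concrete, below is a minimal, illustrative sketch of the key-remapping contract that `_map_to_converted_state` establishes for the loaders added in this file. It is not part of the diff; it assumes the 0.2.0.dev20240617 nightly wheel is installed, and the source checkpoint key `time_embedding.linear_1` and the tensor shapes are hypothetical stand-ins for real Stable Diffusion tensor names.

# Illustrative sketch only (not part of this diff). Assumes the
# 0.2.0.dev20240617 nightly wheel is installed; the source key and shapes
# below are hypothetical stand-ins for real checkpoint tensor names.
import torch

from ai_edge_torch.generative.utilities import stable_diffusion_loader

# A fake checkpoint fragment: one linear layer with a weight and a bias.
state = {
    "time_embedding.linear_1.weight": torch.zeros(1280, 320),
    "time_embedding.linear_1.bias": torch.zeros(1280),
}
converted_state = {}

# _map_to_converted_state pops "<src>.weight" (and ".bias" when present)
# out of `state` and re-keys them under the converted module path.
stable_diffusion_loader._map_to_converted_state(
    state, "time_embedding.linear_1", converted_state, "time_embedding.w1"
)

assert not state  # every mapped tensor is consumed from the source state
assert set(converted_state) == {
    "time_embedding.w1.weight",
    "time_embedding.w1.bias",
}

Because each mapped tensor is popped out of the source state dict, `load(..., strict=True)` in both `AutoEncoderModelLoader` and `DiffusionModelLoader` can verify that no checkpoint tensors were left unmapped and raise the `ValueError` shown above otherwise.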