ai-edge-torch-nightly 0.2.0.dev20240801__py3-none-any.whl → 0.2.0.dev20240803__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of ai-edge-torch-nightly might be problematic.

Files changed (89)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +46 -40
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +14 -6
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +14 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +41 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +36 -13
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/test/loader_test.py +5 -4
  69. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  70. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  71. ai_edge_torch/generative/test/test_quantize.py +45 -48
  72. ai_edge_torch/generative/utilities/loader.py +55 -28
  73. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  74. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  75. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  76. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  77. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  78. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  79. ai_edge_torch/model.py +8 -5
  80. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  81. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  82. ai_edge_torch/quantize/quant_config.py +6 -2
  83. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  84. ai_edge_torch/version.py +16 -0
  85. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/RECORD +89 -88
  87. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240801.dist-info → ai_edge_torch_nightly-0.2.0.dev20240803.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/t5_loader.py CHANGED
@@ -18,11 +18,10 @@ import glob
 import os
 from typing import Callable, Dict

+from ai_edge_torch.generative.layers import model_config
 from safetensors import safe_open
 import torch

-from ai_edge_torch.generative.layers import model_config
-

 def load_safetensors(full_path: str):
   """Loads safetensors into a single state dictionary.
@@ -71,7 +70,11 @@ def load_pytorch_statedict(full_path: str):
   Raises:
     ValueError: If no tensors are loaded from the provided directory or file.
   """
-  pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
+  pattern = (
+      os.path.join(full_path, "*.bin")
+      if os.path.isdir(full_path)
+      else full_path
+  )
   files = []
   for file in glob.glob(pattern):
     files.append(file)
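
The hunk above is purely a line wrap of the same conditional. For reference, what it computes: a checkpoint directory is expanded to a *.bin glob, while a single file path is used as-is. A minimal standalone sketch with a hypothetical helper name:

import glob
import os

def _checkpoint_files(full_path: str) -> list[str]:
  # Directory -> glob over *.bin shards; single file -> that file alone.
  pattern = (
      os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
  )
  return sorted(glob.glob(pattern))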
@@ -131,7 +134,10 @@ class ModelLoader:
     self._loader = self._get_loader()

   def load(
-      self, model: torch.nn.Module, strict: bool = True, fuse_attention: bool = True
+      self,
+      model: torch.nn.Module,
+      strict: bool = True,
+      fuse_attention: bool = True,
   ):
     """Load the model from the checkpoint

@@ -166,11 +172,14 @@ class ModelLoader:

     if strict and state:
       raise ValueError(
-          f"Failed to map all tensor. Remaining tensor are: {list(state.keys())}"
+          "Failed to map all tensor. Remaining tensor are:"
+          f" {list(state.keys())}"
       )
     model.load_state_dict(converted_state, strict=strict)

-  def _do_load(self, model, state, names, additional_prefix="", fuse_attention=True):
+  def _do_load(
+      self, model, state, names, additional_prefix="", fuse_attention=True
+  ):
     """Load the model from the checkpoint

     Args:
@@ -183,7 +192,9 @@ class ModelLoader:
     """
     converted_state = dict()
     if names.embedding is not None:
-      converted_state["tok_embedding.weight"] = state.pop(f"{names.embedding}.weight")
+      converted_state["tok_embedding.weight"] = state.pop(
+          f"{names.embedding}.weight"
+      )
     if names.lm_head is not None:
       converted_state["lm_head.weight"] = state.pop(f"{names.lm_head}.weight")
       if model.config.lm_head_use_bias:
@@ -195,7 +206,9 @@ class ModelLoader:
           f"{final_norm_name}.weight"
       )
       if f"{final_norm_name}.bias" in state:
-        converted_state["final_norm.bias"] = state.pop(f"{final_norm_name}.bias")
+        converted_state["final_norm.bias"] = state.pop(
+            f"{final_norm_name}.bias"
+        )

     if names.relative_attn_bias:
       rel_attn_name = names.relative_attn_bias
@@ -205,7 +218,9 @@ class ModelLoader:
       )

     for i in range(model.config.num_layers):
-      self._map_norm(i, model.config, state, converted_state, names, additional_prefix)
+      self._map_norm(
+          i, model.config, state, converted_state, names, additional_prefix
+      )
       self._map_feedforward(
           i, model.config, state, converted_state, names, additional_prefix
       )
@@ -268,13 +283,19 @@ class ModelLoader:
     if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
       ff_up_proj_name = names.ff_up_proj.format(idx)
       ff_down_proj_name = names.ff_down_proj.format(idx)
-      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(f"{ff_up_proj_name}.weight")
+      converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
+          f"{ff_up_proj_name}.weight"
+      )
       converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
           f"{ff_down_proj_name}.weight"
       )
       if config.ff_config.use_bias:
-        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(f"{ff_up_proj_name}.bias")
-        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(f"{ff_down_proj_name}.bias")
+        converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
+            f"{ff_up_proj_name}.bias"
+        )
+        converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
+            f"{ff_down_proj_name}.bias"
+        )
     else:
       if names.ff_gate_proj is not None:
         ff_up_proj_name = names.ff_up_proj.format(idx)
@@ -290,7 +311,9 @@ class ModelLoader:
             f"{ff_gate_proj_name}.weight"
         )
         if config.ff_config.use_bias:
-          converted_state[f"{prefix}.ff.w3.bias"] = state.pop(f"{ff_up_proj_name}.bias")
+          converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
+              f"{ff_up_proj_name}.bias"
+          )
           converted_state[f"{prefix}.ff.w2.bias"] = state.pop(
               f"{ff_down_proj_name}.bias"
           )
@@ -355,12 +378,12 @@ class ModelLoader:
       )

     o_name = names.attn_output_proj.format(idx)
-    converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
     )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
       )

   def _map_cross_attention(
@@ -385,47 +408,51 @@ class ModelLoader:
     v_name = names.cross_attn_value_proj.format(idx)

     if fuse_attention:
-      converted_state[f"{prefix}.cross_atten_func.attn.weight"] = self._fuse_qkv(
-          config,
-          state.pop(f"{q_name}.weight"),
-          state.pop(f"{k_name}.weight"),
-          state.pop(f"{v_name}.weight"),
+      converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
+          self._fuse_qkv(
+              config,
+              state.pop(f"{q_name}.weight"),
+              state.pop(f"{k_name}.weight"),
+              state.pop(f"{v_name}.weight"),
+          )
       )
       if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = self._fuse_qkv(
-            config,
-            state.pop(f"{q_name}.bias"),
-            state.pop(f"{k_name}.bias"),
-            state.pop(f"{v_name}.bias"),
+        converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
+            self._fuse_qkv(
+                config,
+                state.pop(f"{q_name}.bias"),
+                state.pop(f"{k_name}.bias"),
+                state.pop(f"{v_name}.bias"),
+            )
         )
     else:
-      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = state.pop(
-          f"{q_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.q_projection.weight"] = (
+          state.pop(f"{q_name}.weight")
       )
-      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = state.pop(
-          f"{k_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.k_projection.weight"] = (
+          state.pop(f"{k_name}.weight")
      )
-      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = state.pop(
-          f"{v_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
+          state.pop(f"{v_name}.weight")
      )
      if config.attn_config.qkv_use_bias:
-        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = state.pop(
-            f"{q_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
+            state.pop(f"{q_name}.bias")
        )
-        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = state.pop(
-            f"{k_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.k_projection.bias"] = (
+            state.pop(f"{k_name}.bias")
        )
-        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = state.pop(
-            f"{v_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.v_projection.bias"] = (
+            state.pop(f"{v_name}.bias")
        )

     o_name = names.cross_attn_output_proj.format(idx)
-    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = state.pop(
-        f"{o_name}.weight"
+    converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
+        state.pop(f"{o_name}.weight")
     )
     if config.attn_config.output_proj_use_bias:
-      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = state.pop(
-          f"{o_name}.bias"
+      converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
+          state.pop(f"{o_name}.bias")
      )

   def _map_norm(
@@ -450,12 +477,12 @@ class ModelLoader:

     if names.pre_cross_attn_norm:
       pre_cross_attn_norm_name = names.pre_cross_attn_norm.format(idx)
-      converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = state.pop(
-          f"{pre_cross_attn_norm_name}.weight"
+      converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.weight"] = (
+          state.pop(f"{pre_cross_attn_norm_name}.weight")
       )
       if f"{pre_cross_attn_norm_name}.bias" in state:
-        converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = state.pop(
-            f"{pre_cross_attn_norm_name}.bias"
+        converted_state[f"{prefix}.cross_atten_func.pre_atten_norm.bias"] = (
+            state.pop(f"{pre_cross_attn_norm_name}.bias")
         )

     if names.pre_ff_norm is not None:
@@ -475,7 +502,9 @@ class ModelLoader:
       k: torch.Tensor,
       v: torch.Tensor,
   ) -> torch.Tensor:
-    q_per_kv = config.attn_config.num_heads // config.attn_config.num_query_groups
+    q_per_kv = (
+        config.attn_config.num_heads // config.attn_config.num_query_groups
+    )
     qs = torch.split(q, config.head_dim * q_per_kv)
     ks = torch.split(k, config.head_dim)
     vs = torch.split(v, config.head_dim)
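
For context, the _fuse_qkv helper re-wrapped above packs separate Q/K/V projection tensors into the single fused attention weight used when fuse_attention=True. A minimal standalone sketch of that packing, assuming the per-group interleaving that the torch.split calls suggest (names and layout here are illustrative, not a verbatim copy of the library code):

import torch

def fuse_qkv_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    num_query_groups: int,
    head_dim: int,
) -> torch.Tensor:
  # One query chunk per KV group, plus one K and one V chunk per group.
  q_per_kv = num_heads // num_query_groups
  qs = torch.split(q, head_dim * q_per_kv)
  ks = torch.split(k, head_dim)
  vs = torch.split(v, head_dim)
  # Assumed fused layout: [q_group0, k_group0, v_group0, q_group1, ...]
  return torch.cat([chunk for group in zip(qs, ks, vs) for chunk in group])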
ai_edge_torch/hlfb/mark_pattern/__init__.py CHANGED
@@ -16,11 +16,10 @@ import copy
 from typing import Any
 import uuid

-import torch
-from torch_xla.experimental import xla_marker
-
 from ai_edge_torch.hlfb.mark_pattern.pattern import Pattern
 from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker  # NOQA
+import torch
+from torch_xla.experimental import xla_marker


 @torch._dynamo.assume_constant_result
ai_edge_torch/hlfb/mark_pattern/pattern.py CHANGED
@@ -16,6 +16,7 @@ import copy
 import dataclasses
 from typing import Any, Callable, Optional, Union

+from ai_edge_torch.hlfb.mark_pattern import passes
 import torch
 from torch.export.graph_signature import TensorArgument
 from torch.fx import Graph
@@ -23,8 +24,6 @@ from torch.fx import GraphModule
 from torch.fx.passes.utils.matcher_utils import InternalMatch
 from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

-from ai_edge_torch.hlfb.mark_pattern import passes
-

 def _are_equal(x: Any, y: Any) -> bool:
   if type(x) != type(y):
@@ -69,7 +68,9 @@ class ScalarAttrTracker:
   pattern_arg_pos: int
   transform: Callable = lambda x: x
   inverse_transform: Callable = lambda x: x
-  _source_targets: list[tuple[Any, Any]] = dataclasses.field(default_factory=list)
+  _source_targets: list[tuple[Any, Any]] = dataclasses.field(
+      default_factory=list
+  )

   def track(self, *sources):
     """Register magic values to track the (transformed) attr values in
@@ -78,7 +79,9 @@ class ScalarAttrTracker:
     for source in sources:
       target = self.transform(source)
       if not _are_equal(self.inverse_transform(target), source):
-        raise Exception(f"Invalid transform/inverse_transform for {self.attr_name}")
+        raise Exception(
+            f"Invalid transform/inverse_transform for {self.attr_name}"
+        )
       self._source_targets.append([source, target])
     return self

@@ -189,7 +192,9 @@ class Pattern:

     self.name = name
     self.attr_builder = attr_builder
-    self._scalar_attr_trackers = scalar_attr_trackers if scalar_attr_trackers else []
+    self._scalar_attr_trackers = (
+        scalar_attr_trackers if scalar_attr_trackers else []
+    )

     exported_program = torch.export.export(module, export_args)
     if decomp_table is not None:
@@ -201,7 +206,9 @@ class Pattern:
     self._scalar_attr_locations = []
     for tracker in self._scalar_attr_trackers:
       self._scalar_attr_locations.append(
-          _find_scalar_attr(module, export_args, tracker, decomp_table=decomp_table)
+          _find_scalar_attr(
+              module, export_args, tracker, decomp_table=decomp_table
+          )
       )

     # Sanitize graph_module for more precise pattern matching.
@@ -251,7 +258,9 @@ class Pattern:
       attrs = {}

       for loc in self._scalar_attr_locations:
-        attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(match, loc)
+        attrs[loc.attr_name] = self._get_attr_value_from_pattern_match(
+            match, loc
+        )

       attrs = attrs if attrs else None
       match_with_attrs.append((match, attrs))
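
ScalarAttrTracker, whose field wrapping changed above, is the hook mark_pattern uses to recover scalar attributes (for example a softmax dim) from a traced pattern. A hedged usage sketch based only on what this diff shows (attr_name appears in the error message; the exact constructor signature is an assumption):

from ai_edge_torch.hlfb.mark_pattern.pattern import ScalarAttrTracker

# Track which pattern argument carries the scalar "dim" attribute and register
# the magic values used while probing candidate exports; track() returns self.
dim_tracker = ScalarAttrTracker("dim", pattern_arg_pos=1).track(0, 1, 2)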
ai_edge_torch/hlfb/test/test_mark_pattern.py CHANGED
@@ -15,11 +15,10 @@

 import unittest

+from ai_edge_torch.hlfb import mark_pattern
 import torch
 import torch_xla

-from ai_edge_torch.hlfb import mark_pattern
-

 def _export_stablehlo_mlir(model, args=None):
   if not isinstance(model, torch.export.ExportedProgram):
@@ -73,7 +72,9 @@ class TestMarkPattern(unittest.TestCase):
     mlir = _export_stablehlo_mlir(exported_program)

     self.assertEqual(mlir.count('stablehlo.composite "test.add"'), 2)
-    self.assertEqual(mlir.count('composite_attributes = {alias = "test.test_add"}'), 2)
+    self.assertEqual(
+        mlir.count('composite_attributes = {alias = "test.test_add"}'), 2
+    )

   def test_mark_pattern_with_scalar_attr_tracker(self):
     class TestModel(torch.nn.Module):
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py CHANGED
@@ -15,12 +15,11 @@
 import math
 import unittest

+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F
 import torch_xla

-from ai_edge_torch.hlfb import StableHLOCompositeBuilder
-

 def _export_stablehlo_mlir(model, args):
   ep = torch.export.export(model, args)
@@ -80,7 +79,9 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
         super().__init__()

       def log_softmax(self, x: torch.Tensor, dim: int):
-        builder = StableHLOCompositeBuilder(name="test.log_softmax", attr={"dim": dim})
+        builder = StableHLOCompositeBuilder(
+            name="test.log_softmax", attr={"dim": dim}
+        )
         x = builder.mark_inputs(x)
         y = torch.nn.functional.log_softmax(x, dim=dim)
         y = builder.mark_outputs(y)
@@ -126,7 +127,8 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
     self.assertEqual(
         mlir.count(
-            'composite_attributes = {dim = 0 : i64, source = "torch.nn", version = 1.000000e+00 : f32}'
+            'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
+            " version = 1.000000e+00 : f32}"
         ),
         1,
     )
@@ -236,8 +238,12 @@ class TestStableHLOCompositeBuilder(unittest.TestCase):
     self.assertEqual(
         mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
     )
-    self.assertEqual(mlir.count("composite_attributes = {include_captanh = true}"), 1)
-    self.assertEqual(mlir.count("composite_attributes = {include_captanh = false}"), 1)
+    self.assertEqual(
+        mlir.count("composite_attributes = {include_captanh = true}"), 1
+    )
+    self.assertEqual(
+        mlir.count("composite_attributes = {include_captanh = false}"), 1
+    )

   def test_build_composite_with_multiple_inputs_outputs(self):
     class SampleModel(torch.nn.Module):
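
The test edits above only re-wrap calls; the builder flow they exercise is visible in the same hunks. A condensed sketch of that flow (module and attribute values are illustrative):

import torch
from ai_edge_torch.hlfb import StableHLOCompositeBuilder

class LogSoftmax(torch.nn.Module):

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Wrap the region in a named StableHLO composite and attach "dim" as an attribute.
    builder = StableHLOCompositeBuilder(name="test.log_softmax", attr={"dim": -1})
    x = builder.mark_inputs(x)
    y = torch.nn.functional.log_softmax(x, dim=-1)
    return builder.mark_outputs(y)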
ai_edge_torch/model.py CHANGED
@@ -21,12 +21,11 @@ from __future__ import annotations

 import abc

+from ai_edge_torch.convert import conversion_utils as cutils
 import numpy as np
 import numpy.typing as npt
 import tensorflow as tf

-from ai_edge_torch.convert import conversion_utils as cutils
-

 class Model(abc.ABC):
   """Represents and edge model."""
@@ -84,7 +83,8 @@ class TfLiteModel(Model):
     signature_list = interpreter.get_signature_list()
     if signature_name not in signature_list:
       raise ValueError(
-          f"Invalid signature name provided. Available signatures: {', '.join(signature_list.keys())}"
+          'Invalid signature name provided. Available signatures:'
+          f' {", ".join(signature_list.keys())}'
       )

     try:
@@ -92,14 +92,17 @@ class TfLiteModel(Model):
     except ValueError as exception:
       if 'Invalid signature_key provided.' in str(exception):
         raise ValueError(
-            f'Invalid signature key provided. Available signatures: {list(signature_list.keys())}'
+            'Invalid signature key provided. Available signatures:'
+            f' {list(signature_list.keys())}'
         )
       else:
         raise exception

     if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
       raise ValueError(
-          f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided."
+          'The model requires'
+          f' {len(signature_list[signature_name]["inputs"])} arguments but'
+          f' {len(args)} was provided.'
       )

     # Gather the input dictionary based on the signature.
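
The wording changes above sit in the error paths of TfLiteModel's signature dispatch; the underlying calls are the standard TensorFlow Lite interpreter signature API. A minimal sketch of that flow outside the class (the model path and signature name are illustrative):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
signature_list = interpreter.get_signature_list()
if "serving_default" not in signature_list:
  raise ValueError(
      "Invalid signature name provided. Available signatures:"
      f" {', '.join(signature_list.keys())}"
  )
# Invalid keys surface as the 'Invalid signature_key provided.' ValueError handled above.
runner = interpreter.get_signature_runner("serving_default")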
ai_edge_torch/quantize/pt2e_quantizer.py CHANGED
@@ -19,6 +19,12 @@ import copy
 import functools
 from typing import Any, Callable, Dict, List, Optional, Set

+from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs  # NOQA
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
+from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
+from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
+from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
 import torch
 from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
 from torch.ao.quantization.observer import HistogramObserver
@@ -34,20 +40,15 @@ from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import Node
 import torch.nn.functional as F

-from ai_edge_torch.quantize.pt2e_quantizer_utils import _convert_scalars_to_attrs  # NOQA
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OP_TO_ANNOTATOR
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorConfig
-from ai_edge_torch.quantize.pt2e_quantizer_utils import OperatorPatternType
-from ai_edge_torch.quantize.pt2e_quantizer_utils import propagate_annotation
-from ai_edge_torch.quantize.pt2e_quantizer_utils import QuantizationConfig
-
 __all__ = [
     "PT2EQuantizer",
     "get_symmetric_quantization_config",
 ]


-def _supported_symmetric_quantized_operators() -> Dict[str, List[OperatorPatternType]]:
+def _supported_symmetric_quantized_operators() -> (
+    Dict[str, List[OperatorPatternType]]
+):
   supported_operators: Dict[str, List[OperatorPatternType]] = {
       # Both conv and linear should be able to handle relu + hardtanh fusion since
       # those are clamp ops
@@ -92,7 +93,9 @@ def get_symmetric_quantization_config(
 ):
   if is_qat:
     if is_dynamic:
-      raise NotImplementedError("dynamic quantization for qat is not yet implemented.")
+      raise NotImplementedError(
+          "dynamic quantization for qat is not yet implemented."
+      )
     act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
   else:
     if is_dynamic:
@@ -106,12 +109,18 @@ def get_symmetric_quantization_config(
       quant_max=127,
       qscheme=torch.per_tensor_affine,
       is_dynamic=is_dynamic,
-      observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(eps=2**-12),
+      observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
+          eps=2**-12
+      ),
   )
   qscheme = (
-      torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
+      torch.per_channel_symmetric
+      if is_per_channel
+      else torch.per_tensor_symmetric
+  )
+  weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+      MinMaxObserver
   )
-  weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = MinMaxObserver
   if is_qat:
     weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
   elif is_per_channel:
@@ -197,7 +206,9 @@ def _get_module_name_filter(module_name: str):
     # }
     # get_attr nodes doesn't have nn_module_stack?
     nn_module_stack = n.meta.get("nn_module_stack", {})
-    names = [n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()]
+    names = [
+        n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys()
+    ]
     return module_name in names

   return module_name_filter
@@ -232,7 +243,9 @@ def _get_not_module_type_or_name_filter(
     tp_list: List[Callable], module_name_list: List[str]
 ) -> Callable[[Node], bool]:
   module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-  module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+  module_name_list_filters = [
+      _get_module_name_filter(m) for m in module_name_list
+  ]

   def not_module_type_or_name_filter(n: Node) -> bool:
     return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -307,7 +320,9 @@ class PT2EQuantizer(Quantizer):
         return ops
     return []

-  def set_global(self, quantization_config: QuantizationConfig) -> PT2EQuantizer:
+  def set_global(
+      self, quantization_config: QuantizationConfig
+  ) -> PT2EQuantizer:
     self.global_config = quantization_config
     return self

ai_edge_torch/quantize/pt2e_quantizer_utils.py CHANGED
@@ -154,7 +154,9 @@ def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
       torch.per_tensor_symmetric,
       torch.per_channel_symmetric,
   ]:
-    raise ValueError(f"Unsupported quantization_spec {quantization_spec} for weight")
+    raise ValueError(
+        f"Unsupported quantization_spec {quantization_spec} for weight"
+    )
   return quantization_spec


@@ -193,7 +195,10 @@ def _annotate_linear(
   weight_qspec = get_weight_qspec(quantization_config)
   bias_qspec = get_bias_qspec(quantization_config)
   for node in gm.graph.nodes:
-    if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
+    if (
+        node.op != "call_function"
+        or node.target != torch.ops.aten.linear.default
+    ):
       continue
     if filter_fn and not filter_fn(node):
       continue
@@ -417,7 +422,9 @@ def _annotate_conv_bn(
   Find conv + batchnorm parititions
   Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
   """
-  return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=False)
+  return _do_annotate_conv_bn(
+      gm, quantization_config, filter_fn, has_relu=False
+  )


 @register_annotator("conv_bn_relu")
@@ -486,7 +493,9 @@ def _do_annotate_conv_bn(
   # Match against all conv dimensions and cuda variants
   for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
     pattern = get_pattern(conv_fn, relu_is_inplace)
-    pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
+    pattern = _get_aten_graph_module_for_pattern(
+        pattern, example_inputs, is_cuda
+    )
     pattern.graph.eliminate_dead_code()
     pattern.recompile()
     matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
@@ -676,7 +685,9 @@ def _annotate_adaptive_avg_pool2d(
         and pool_node.target != torch.ops.aten.mean.dim
         and pool_node.target != torch.ops.aten.as_strided_.default
     ):
-      raise ValueError(f"{pool_node} is not an aten adaptive_avg_pool2d operator")
+      raise ValueError(
+          f"{pool_node} is not an aten adaptive_avg_pool2d operator"
+      )

     if _is_annotated([pool_node]):
       continue
@@ -741,7 +752,8 @@ def _annotate_fixed_qparams(
       continue

     node.meta["quantization_annotation"] = QuantizationAnnotation(
-        output_qspec=get_fixed_qparams_qspec(quantization_config), _annotated=True
+        output_qspec=get_fixed_qparams_qspec(quantization_config),
+        _annotated=True,
     )
     _mark_nodes_as_annotated(partition)
     annotated_partitions.append(partition)
@@ -885,7 +897,9 @@ def _annotate_mul(
     filter_fn: Optional[Callable[[Node], bool]] = None,
 ) -> Optional[List[List[Node]]]:
   mul_partitions = get_source_partitions(
-      gm.graph, ["mul", "mul_", operator.mul, torch.mul, operator.imul], filter_fn
+      gm.graph,
+      ["mul", "mul_", operator.mul, torch.mul, operator.imul],
+      filter_fn,
   )
   mul_partitions = list(itertools.chain(*mul_partitions.values()))
   annotated_partitions = []
@@ -932,8 +946,9 @@ def _annotate_cat(

   if cat_node.target != torch.ops.aten.cat.default:
     raise Exception(
-        f"Expected cat node: torch.ops.aten.cat.default, but found {cat_node.target}"
-        " please check if you are calling the correct capture API"
+        "Expected cat node: torch.ops.aten.cat.default, but found"
+        f" {cat_node.target} please check if you are calling the correct"
+        " capture API"
     )

   annotated_partitions.append(cat_partition.nodes)
@@ -987,7 +1002,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:
     if not isinstance(prev_node, Node):
       continue

-    quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+    quantization_annotation = prev_node.meta.get(
+        "quantization_annotation", None
+    )
     if not quantization_annotation:
       continue

@@ -1014,7 +1031,9 @@ def propagate_annotation(model: torch.fx.GraphModule) -> None:


 # TODO: make the list of ops customizable
-def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def _convert_scalars_to_attrs(
+    model: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
   for n in model.graph.nodes:
     if n.op != "call_function" or n.target not in [
         torch.ops.aten.add.Tensor,
ai_edge_torch/quantize/quant_config.py CHANGED
@@ -76,6 +76,10 @@ class QuantConfig:
     elif generative_recipe is not None:
       generative_recipe.verify()
       object.__setattr__(self, 'generative_recipe', generative_recipe)
-      object.__setattr__(self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER)
+      object.__setattr__(
+          self, '_quantizer_mode', self._QuantizerMode.AI_EDGE_QUANTIZER
+      )
     else:
-      raise ValueError('Either pt2e_quantizer or generative_recipe must be set.')
+      raise ValueError(
+          'Either pt2e_quantizer or generative_recipe must be set.'
+      )
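
The validation reformatted above implies QuantConfig accepts exactly one of a PT2E quantizer or a generative quantization recipe. A hedged sketch that combines it with the PT2EQuantizer pieces touched earlier in this diff (keyword names are inferred from the code shown, not from documentation):

from ai_edge_torch.quantize.pt2e_quantizer import get_symmetric_quantization_config
from ai_edge_torch.quantize.pt2e_quantizer import PT2EQuantizer
from ai_edge_torch.quantize.quant_config import QuantConfig

# PT2E path: a quantizer with a global symmetric config (per-channel weights assumed).
quantizer = PT2EQuantizer().set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=False)
)
quant_config = QuantConfig(pt2e_quantizer=quantizer)
# Passing neither pt2e_quantizer nor generative_recipe raises the ValueError above.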