ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240802__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly may be problematic; see the package's registry page for more details.

Files changed (89)
  1. ai_edge_torch/convert/conversion.py +12 -8
  2. ai_edge_torch/convert/conversion_utils.py +38 -20
  3. ai_edge_torch/convert/converter.py +11 -5
  4. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  5. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  6. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +45 -36
  7. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  8. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  9. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  15. ai_edge_torch/convert/test/test_convert.py +39 -16
  16. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  17. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  18. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  19. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  20. ai_edge_torch/debug/culprit.py +41 -16
  21. ai_edge_torch/debug/test/test_culprit.py +4 -3
  22. ai_edge_torch/debug/test/test_search_model.py +4 -3
  23. ai_edge_torch/debug/utils.py +3 -1
  24. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  25. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  26. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  27. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  28. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  30. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  31. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  32. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  33. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  34. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  35. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  36. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
  37. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
  38. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
  39. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  40. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
  41. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  45. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  46. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  47. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  52. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  53. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  54. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  55. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  56. ai_edge_torch/generative/layers/attention.py +19 -11
  57. ai_edge_torch/generative/layers/builder.py +3 -4
  58. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  59. ai_edge_torch/generative/layers/model_config.py +6 -2
  60. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  61. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  62. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  63. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  64. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  65. ai_edge_torch/generative/quantize/example.py +2 -3
  66. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  67. ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
  68. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  69. ai_edge_torch/generative/test/loader_test.py +5 -4
  70. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  71. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  72. ai_edge_torch/generative/test/test_quantize.py +45 -47
  73. ai_edge_torch/generative/utilities/loader.py +55 -28
  74. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  75. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  76. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  77. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  78. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  79. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  80. ai_edge_torch/model.py +8 -5
  81. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  82. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  83. ai_edge_torch/quantize/quant_config.py +6 -2
  84. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  85. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/METADATA +1 -1
  86. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/RECORD +89 -89
  87. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/LICENSE +0 -0
  88. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/WHEEL +0 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240802.dist-info}/top_level.txt +0 -0
@@ -16,13 +16,6 @@ import operator
16
16
  import os
17
17
  from typing import Optional, Tuple, Union
18
18
 
19
- import torch
20
- import torch.ao.quantization.quantize_pt2e
21
- from torch.export import ExportedProgram
22
- from torch.fx import GraphModule
23
- from torch.fx import Node
24
- import torch.utils._pytree as pytree
25
-
26
19
  from ai_edge_torch.convert.fx_passes import ExportedProgramPassBase
27
20
  from ai_edge_torch.convert.fx_passes import ExportedProgramPassResult
28
21
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
@@ -30,6 +23,12 @@ from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layo
30
23
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_partitioners # NOQA
31
24
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite # NOQA
32
25
  from ai_edge_torch.convert.fx_passes.optimize_layout_transposes_pass import utils # NOQA
26
+ import torch
27
+ import torch.ao.quantization.quantize_pt2e
28
+ from torch.export import ExportedProgram
29
+ from torch.fx import GraphModule
30
+ from torch.fx import Node
31
+ import torch.utils._pytree as pytree
33
32
 
34
33
  TransposeFunc = Union[utils.tensor_to_nchw, utils.tensor_to_nhwc]
35
34
 
@@ -208,7 +207,8 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
208
207
 
209
208
  if not layout_check.is_4d(input_node):
210
209
  raise AssertionError(
211
- f"Attempting to convert non-NHWC compatible node to NHWC: {input_node}"
210
+ "Attempting to convert non-NHWC compatible node to NHWC:"
211
+ f" {input_node}"
212
212
  )
213
213
 
214
214
  # Assign target node's source meta to the to_NHWC node, because the transpose
@@ -250,7 +250,9 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
250
250
 
251
251
  for node in graph.nodes:
252
252
  has_input_nodes = len(node.all_input_nodes) > 0
253
- all_inputs_are_const = all(map(layout_mark.is_const_node, node.all_input_nodes))
253
+ all_inputs_are_const = all(
254
+ map(layout_mark.is_const_node, node.all_input_nodes)
255
+ )
254
256
  if (
255
257
  node.name in non_user_input_names
256
258
  or (has_input_nodes and all_inputs_are_const)
@@ -262,7 +264,9 @@ class OptimizeLayoutTransposesPass(ExportedProgramPassBase):
262
264
  self.mark_const_nodes(exported_program)
263
265
 
264
266
  graph_module = exported_program.graph_module
265
- partitioner = os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None)
267
+ partitioner = os.environ.get(
268
+ "AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", None
269
+ )
266
270
  if partitioner == "MINCUT":
267
271
  graph_module = layout_partitioners.min_cut.partition(graph_module)
268
272
  elif partitioner == "GREEDY":
@@ -20,15 +20,14 @@ import tempfile
20
20
  from typing import Tuple
21
21
  import unittest
22
22
 
23
+ import ai_edge_torch
24
+ from ai_edge_torch.convert import conversion_utils as cutils
25
+ from ai_edge_torch.testing import model_coverage
23
26
  import numpy as np
24
27
  import tensorflow as tf
25
28
  import torch
26
29
  import torchvision
27
30
 
28
- import ai_edge_torch
29
- from ai_edge_torch.convert import conversion_utils as cutils
30
- from ai_edge_torch.testing import model_coverage
31
-
32
31
 
33
32
  @dataclass
34
33
  class TestContainer1:
@@ -36,7 +35,9 @@ class TestContainer1:
36
35
  data_2: Tuple[torch.Tensor, torch.Tensor]
37
36
 
38
37
 
39
- torch.export.register_dataclass(TestContainer1, serialized_type_name="TestContainer1")
38
+ torch.export.register_dataclass(
39
+ TestContainer1, serialized_type_name="TestContainer1"
40
+ )
40
41
 
41
42
 
42
43
  class TestConvert(unittest.TestCase):
@@ -60,7 +61,9 @@ class TestConvert(unittest.TestCase):
60
61
  torch_module = Add().eval()
61
62
  edge_model = ai_edge_torch.convert(torch_module, args)
62
63
 
63
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
64
+ self.assertTrue(
65
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
66
+ )
64
67
 
65
68
  def test_convert_dot_add(self):
66
69
  class DotAdd(torch.nn.Module):
@@ -77,14 +80,18 @@ class TestConvert(unittest.TestCase):
77
80
  torch_module = DotAdd().eval()
78
81
  edge_model = ai_edge_torch.convert(torch_module, args)
79
82
 
80
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
83
+ self.assertTrue(
84
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
85
+ )
81
86
 
82
87
  def test_convert_resnet18(self):
83
88
  args = (torch.randn(4, 3, 224, 224),)
84
89
  torch_module = torchvision.models.resnet18().eval()
85
90
  edge_model = ai_edge_torch.convert(torch_module, args)
86
91
 
87
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
92
+ self.assertTrue(
93
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
94
+ )
88
95
 
89
96
  def test_signature_args_ordering(self):
90
97
  """Tests conversion of a model with more than 10 arguments."""
@@ -156,7 +163,9 @@ class TestConvert(unittest.TestCase):
156
163
  torch_model = BasicAddModelWithMultipleOutputs().eval()
157
164
  edge_model = ai_edge_torch.convert(torch_model, sample_input)
158
165
 
159
- result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
166
+ result = model_coverage.compare_tflite_torch(
167
+ edge_model, torch_model, sample_input
168
+ )
160
169
  self.assertTrue(result)
161
170
 
162
171
  def test_12_outputs_model(self):
@@ -201,7 +210,9 @@ class TestConvert(unittest.TestCase):
201
210
  torch_model = BasicAddModelWithMultipleOutputs().eval()
202
211
  edge_model = ai_edge_torch.convert(torch_model, sample_input)
203
212
 
204
- result = model_coverage.compare_tflite_torch(edge_model, torch_model, sample_input)
213
+ result = model_coverage.compare_tflite_torch(
214
+ edge_model, torch_model, sample_input
215
+ )
205
216
  self.assertTrue(result)
206
217
 
207
218
  def test_apply_tfl_backdoor_flags(self):
@@ -244,7 +255,9 @@ class TestConvert(unittest.TestCase):
244
255
  tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
245
256
  )
246
257
  ai_edge_torch.convert(
247
- torch_module, args, _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path}
258
+ torch_module,
259
+ args,
260
+ _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path},
248
261
  )
249
262
  self.assertTrue(os.path.isdir(ir_dump_path))
250
263
 
@@ -296,7 +309,9 @@ class TestConvert(unittest.TestCase):
296
309
  edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
297
310
 
298
311
  self.assertTrue(
299
- model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
312
+ model_coverage.compare_tflite_torch(
313
+ edge_model, model, kwargs=kwargs_gen
314
+ )
300
315
  )
301
316
 
302
317
  def test_convert_model_with_args_kwargs(self):
@@ -316,7 +331,9 @@ class TestConvert(unittest.TestCase):
316
331
  edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
317
332
 
318
333
  self.assertTrue(
319
- model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
334
+ model_coverage.compare_tflite_torch(
335
+ edge_model, model, args_gen, kwargs_gen
336
+ )
320
337
  )
321
338
 
322
339
  def test_convert_model_with_args_nested_kwargs_1(self):
@@ -344,7 +361,9 @@ class TestConvert(unittest.TestCase):
344
361
  "z_data_2_0": kwargs["z"].data_2[0].numpy(),
345
362
  "z_data_2_1": kwargs["z"].data_2[1].numpy(),
346
363
  }
347
- self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
364
+ self._compare_tflite_torch_args_kwargs(
365
+ SampleModel(), args, kwargs, flat_inputs
366
+ )
348
367
 
349
368
  def test_convert_model_with_args_nested_kwargs_2(self):
350
369
  """
@@ -371,7 +390,9 @@ class TestConvert(unittest.TestCase):
371
390
  "z_data_2_0_0": kwargs["z"].data_2[0][0].numpy(),
372
391
  "z_data_2_1": kwargs["z"].data_2[1].numpy(),
373
392
  }
374
- self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
393
+ self._compare_tflite_torch_args_kwargs(
394
+ SampleModel(), args, kwargs, flat_inputs
395
+ )
375
396
 
376
397
  def test_convert_model_with_args_nested_kwargs_3(self):
377
398
  """
@@ -398,7 +419,9 @@ class TestConvert(unittest.TestCase):
398
419
  "z_data_2_0_foo": kwargs["z"].data_2[0]["foo"].numpy(),
399
420
  "z_data_2_1": kwargs["z"].data_2[1].numpy(),
400
421
  }
401
- self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
422
+ self._compare_tflite_torch_args_kwargs(
423
+ SampleModel(), args, kwargs, flat_inputs
424
+ )
402
425
 
403
426
  def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
404
427
  model.eval()
@@ -17,11 +17,10 @@
17
17
  from typing import Callable
18
18
  import unittest
19
19
 
20
- import parameterized
21
- import torch
22
-
23
20
  import ai_edge_torch
24
21
  from ai_edge_torch.testing import model_coverage
22
+ import parameterized
23
+ import torch
25
24
 
26
25
 
27
26
  def _func_to_torch_module(func: Callable):
@@ -47,35 +46,35 @@ class TestConvertComposites(unittest.TestCase):
47
46
  torch_module = torch.nn.Hardswish().eval()
48
47
  edge_model = ai_edge_torch.convert(torch_module, args)
49
48
 
50
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
51
-
52
- @parameterized.parameterized.expand(
53
- [
54
- # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
55
- # no padding, stride = 1
56
- ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
57
- # add stride
58
- ([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
59
- # default values
60
- ([1, 3, 6, 6], [3, 3]),
61
- # add padding
62
- ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
63
- # add different padding for different dims
64
- ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
65
- # add both stride and padding
66
- ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
67
- # count_include_pad = False
68
- ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
69
- # ceil_mode = True
70
- ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
71
- # ceil_mode = True, stride=[3, 3]
72
- ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
73
- # set divisor_override
74
- ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
75
- # padding set to one number
76
- ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
77
- ]
78
- )
49
+ self.assertTrue(
50
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
51
+ )
52
+
53
+ @parameterized.parameterized.expand([
54
+ # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
55
+ # no padding, stride = 1
56
+ ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
57
+ # add stride
58
+ ([1, 3, 6, 6], [3, 3], [2, 2], [0, 0], False, True, None),
59
+ # default values
60
+ ([1, 3, 6, 6], [3, 3]),
61
+ # add padding
62
+ ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, True, None),
63
+ # add different padding for different dims
64
+ ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
65
+ # add both stride and padding
66
+ ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
67
+ # count_include_pad = False
68
+ ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
69
+ # ceil_mode = True
70
+ ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], True, True, None),
71
+ # ceil_mode = True, stride=[3, 3]
72
+ ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
73
+ # set divisor_override
74
+ ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
75
+ # padding set to one number
76
+ ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
77
+ ])
79
78
  def test_convert_avg_pool2d(self, input_size, *args):
80
79
  """Tests conversion of a module containing an avg_pool2d aten."""
81
80
  torch_module = _func_to_torch_module(
@@ -85,51 +84,65 @@ class TestConvertComposites(unittest.TestCase):
85
84
  edge_model = ai_edge_torch.convert(torch_module, tracing_args)
86
85
 
87
86
  self.assertTrue(
88
- model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
87
+ model_coverage.compare_tflite_torch(
88
+ edge_model, torch_module, tracing_args
89
+ )
89
90
  )
90
91
 
91
- @parameterized.parameterized.expand(
92
- [
93
- # use scale_factor with align_corners=False
94
- (
95
- [1, 3, 10, 10],
96
- dict(scale_factor=3.0, mode='bilinear', align_corners=False),
97
- ),
98
- # use scale_factor with align_corners=true
99
- ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
100
- # use size
101
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
102
- # use size with align_corners=true
103
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
104
- ]
105
- )
92
+ @parameterized.parameterized.expand([
93
+ # use scale_factor with align_corners=False
94
+ (
95
+ [1, 3, 10, 10],
96
+ dict(scale_factor=3.0, mode='bilinear', align_corners=False),
97
+ ),
98
+ # use scale_factor with align_corners=true
99
+ (
100
+ [1, 3, 10, 10],
101
+ dict(scale_factor=3.0, mode='bilinear', align_corners=True),
102
+ ),
103
+ # use size
104
+ ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
105
+ # use size with align_corners=true
106
+ (
107
+ [1, 3, 10, 10],
108
+ dict(size=[15, 20], mode='bilinear', align_corners=True),
109
+ ),
110
+ ])
106
111
  def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
107
112
  """Tests conversion of a torch.nn.functional.upsample module."""
108
113
  torch_module = _func_to_torch_module(
109
- lambda input_tensor: torch.nn.functional.upsample(input_tensor, **kwargs)
114
+ lambda input_tensor: torch.nn.functional.upsample(
115
+ input_tensor, **kwargs
116
+ )
110
117
  )
111
118
  tracing_args = (torch.randn(*input_size),)
112
119
  edge_model = ai_edge_torch.convert(torch_module, tracing_args)
113
120
 
114
121
  self.assertTrue(
115
- model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
122
+ model_coverage.compare_tflite_torch(
123
+ edge_model, torch_module, tracing_args
124
+ )
116
125
  )
117
126
 
118
- @parameterized.parameterized.expand(
119
- [
120
- # use scale_factor with align_corners=False
121
- (
122
- [1, 3, 10, 10],
123
- dict(scale_factor=3.0, mode='bilinear', align_corners=False),
124
- ),
125
- # use scale_factor with align_corners=true
126
- ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
127
- # use size
128
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
129
- # use size with align_corners=true
130
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
131
- ]
132
- )
127
+ @parameterized.parameterized.expand([
128
+ # use scale_factor with align_corners=False
129
+ (
130
+ [1, 3, 10, 10],
131
+ dict(scale_factor=3.0, mode='bilinear', align_corners=False),
132
+ ),
133
+ # use scale_factor with align_corners=true
134
+ (
135
+ [1, 3, 10, 10],
136
+ dict(scale_factor=3.0, mode='bilinear', align_corners=True),
137
+ ),
138
+ # use size
139
+ ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
140
+ # use size with align_corners=true
141
+ (
142
+ [1, 3, 10, 10],
143
+ dict(size=[15, 20], mode='bilinear', align_corners=True),
144
+ ),
145
+ ])
133
146
  def test_convert_upsample_bilinear(self, input_size, kwargs):
134
147
  """Tests conversion of a torch.nn.Upsample module."""
135
148
  torch_module = _func_to_torch_module(
@@ -139,34 +152,44 @@ class TestConvertComposites(unittest.TestCase):
139
152
  edge_model = ai_edge_torch.convert(torch_module, tracing_args)
140
153
 
141
154
  self.assertTrue(
142
- model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
155
+ model_coverage.compare_tflite_torch(
156
+ edge_model, torch_module, tracing_args
157
+ )
143
158
  )
144
159
 
145
- @parameterized.parameterized.expand(
146
- [
147
- # use scale_factor with align_corners=False
148
- (
149
- [1, 3, 10, 10],
150
- dict(scale_factor=3.0, mode='bilinear', align_corners=False),
151
- ),
152
- # use scale_factor with align_corners=true
153
- ([1, 3, 10, 10], dict(scale_factor=3.0, mode='bilinear', align_corners=True)),
154
- # use size
155
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
156
- # use size with align_corners=true
157
- ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear', align_corners=True)),
158
- ]
159
- )
160
+ @parameterized.parameterized.expand([
161
+ # use scale_factor with align_corners=False
162
+ (
163
+ [1, 3, 10, 10],
164
+ dict(scale_factor=3.0, mode='bilinear', align_corners=False),
165
+ ),
166
+ # use scale_factor with align_corners=true
167
+ (
168
+ [1, 3, 10, 10],
169
+ dict(scale_factor=3.0, mode='bilinear', align_corners=True),
170
+ ),
171
+ # use size
172
+ ([1, 3, 10, 10], dict(size=[15, 20], mode='bilinear')),
173
+ # use size with align_corners=true
174
+ (
175
+ [1, 3, 10, 10],
176
+ dict(size=[15, 20], mode='bilinear', align_corners=True),
177
+ ),
178
+ ])
160
179
  def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
161
180
  """Tests conversion of a torch.nn.functional.interpolate module."""
162
181
  torch_module = _func_to_torch_module(
163
- lambda input_tensor: torch.nn.functional.interpolate(input_tensor, **kwargs)
182
+ lambda input_tensor: torch.nn.functional.interpolate(
183
+ input_tensor, **kwargs
184
+ )
164
185
  )
165
186
  tracing_args = (torch.randn(*input_size),)
166
187
  edge_model = ai_edge_torch.convert(torch_module, tracing_args)
167
188
 
168
189
  self.assertTrue(
169
- model_coverage.compare_tflite_torch(edge_model, torch_module, tracing_args)
190
+ model_coverage.compare_tflite_torch(
191
+ edge_model, torch_module, tracing_args
192
+ )
170
193
  )
171
194
 
172
195
  def test_convert_gelu(self):
@@ -176,7 +199,9 @@ class TestConvertComposites(unittest.TestCase):
176
199
  torch_module = torch.nn.GELU().eval()
177
200
  edge_model = ai_edge_torch.convert(torch_module, args)
178
201
 
179
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
202
+ self.assertTrue(
203
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
204
+ )
180
205
 
181
206
  def test_convert_gelu_approximate(self):
182
207
  """Tests conversion of an Approximate GELU module."""
@@ -185,7 +210,9 @@ class TestConvertComposites(unittest.TestCase):
185
210
  torch_module = torch.nn.GELU('tanh').eval()
186
211
  edge_model = ai_edge_torch.convert(torch_module, args)
187
212
 
188
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
213
+ self.assertTrue(
214
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
215
+ )
189
216
 
190
217
  def test_convert_embedding_lookup(self):
191
218
  """Tests conversion of an Embedding module."""
@@ -194,7 +221,9 @@ class TestConvertComposites(unittest.TestCase):
194
221
  torch_module = torch.nn.Embedding(10, 10)
195
222
  edge_model = ai_edge_torch.convert(torch_module, args)
196
223
 
197
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
224
+ self.assertTrue(
225
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
226
+ )
198
227
 
199
228
 
200
229
  if __name__ == '__main__':
@@ -15,11 +15,10 @@
15
15
 
16
16
  import unittest
17
17
 
18
- import torch
19
- import torchvision
20
-
21
18
  import ai_edge_torch
22
19
  from ai_edge_torch.testing import model_coverage
20
+ import torch
21
+ import torchvision
23
22
 
24
23
 
25
24
  class TestConvertMultiSignature(unittest.TestCase):
@@ -41,7 +40,9 @@ class TestConvertMultiSignature(unittest.TestCase):
41
40
  signature_name, torch_module, large_args
42
41
  ).convert(torch_module, args)
43
42
 
44
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
43
+ self.assertTrue(
44
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
45
+ )
45
46
  self.assertTrue(
46
47
  model_coverage.compare_tflite_torch(
47
48
  edge_model, torch_module, large_args, signature_name=signature_name
@@ -74,7 +75,10 @@ class TestConvertMultiSignature(unittest.TestCase):
74
75
  )
75
76
  self.assertTrue(
76
77
  model_coverage.compare_tflite_torch(
77
- edge_model, torch_module, large_args, signature_name=signature_name_2
78
+ edge_model,
79
+ torch_module,
80
+ large_args,
81
+ signature_name=signature_name_2,
78
82
  )
79
83
  )
80
84
 
@@ -87,11 +91,13 @@ class TestConvertMultiSignature(unittest.TestCase):
87
91
 
88
92
  signature_name = "large_input"
89
93
 
90
- edge_model = ai_edge_torch.signature(signature_name, torch_module, args).convert(
91
- torch_module, large_args
92
- )
94
+ edge_model = ai_edge_torch.signature(
95
+ signature_name, torch_module, args
96
+ ).convert(torch_module, large_args)
93
97
 
94
- self.assertTrue(model_coverage.compare_tflite_torch(edge_model, torch_module, args))
98
+ self.assertTrue(
99
+ model_coverage.compare_tflite_torch(edge_model, torch_module, args)
100
+ )
95
101
  self.assertTrue(
96
102
  model_coverage.compare_tflite_torch(
97
103
  edge_model, torch_module, large_args, signature_name=signature_name
@@ -110,7 +116,9 @@ class TestConvertMultiSignature(unittest.TestCase):
110
116
  resnet_signature_name = "resnet18"
111
117
 
112
118
  edge_model = (
113
- ai_edge_torch.signature(mobilenet_signature_name, mobilentv2, mobilenet_args)
119
+ ai_edge_torch.signature(
120
+ mobilenet_signature_name, mobilentv2, mobilenet_args
121
+ )
114
122
  .signature(resnet_signature_name, resnet18, resnet_args)
115
123
  .convert(resnet18, resnet_args)
116
124
  )
@@ -15,9 +15,8 @@
15
15
 
16
16
  import unittest
17
17
 
18
- import torch
19
-
20
18
  import ai_edge_torch
19
+ import torch
21
20
 
22
21
 
23
22
  class Identity(torch.nn.Module):
@@ -31,7 +31,9 @@ class ChannelLastIOWrapper(nn.Module):
31
31
  if not torch.is_tensor(x):
32
32
  raise ValueError("Input must be a torch tensor")
33
33
  if x.ndim < 3:
34
- raise ValueError("Input must be a tensor with rank >= 3 in layout (N, C, ...)")
34
+ raise ValueError(
35
+ "Input must be a tensor with rank >= 3 in layout (N, C, ...)"
36
+ )
35
37
  dims = [0, *range(2, x.ndim), 1]
36
38
  return torch.permute(x, dims)
37
39
 
@@ -39,7 +41,9 @@ class ChannelLastIOWrapper(nn.Module):
39
41
  if not torch.is_tensor(x):
40
42
  raise ValueError("Input must be a torch tensor.")
41
43
  if x.ndim < 3:
42
- raise ValueError("Input must be a tensor with rank >= 3 in layout (N, ..., C)")
44
+ raise ValueError(
45
+ "Input must be a tensor with rank >= 3 in layout (N, ..., C)"
46
+ )
43
47
  dims = [0, x.ndim - 1, *range(1, x.ndim - 1)]
44
48
  return torch.permute(x, dims)
45
49