ai-edge-torch-nightly 0.2.0.dev20240805__py3-none-any.whl → 0.2.0.dev20240808__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (104) hide show
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.2.0.dev20240808.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240805.dist-info → ai_edge_torch_nightly-0.2.0.dev20240808.dist-info}/top_level.txt +0 -0
@@ -12,24 +12,25 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Tests for ai_edge_torch.convert."""
15
16
 
16
-
17
- from dataclasses import dataclass
17
+ import dataclasses
18
18
  import os
19
- import tempfile
20
19
  from typing import Tuple
21
- import unittest
22
20
 
23
21
  import ai_edge_torch
24
- from ai_edge_torch.convert import conversion_utils as cutils
22
+ from ai_edge_torch import config
23
+ from ai_edge_torch._convert import conversion_utils
25
24
  from ai_edge_torch.testing import model_coverage
26
25
  import numpy as np
27
26
  import tensorflow as tf
28
27
  import torch
29
28
  import torchvision
30
29
 
30
+ from tensorflow.python.platform import googletest
31
+
31
32
 
32
- @dataclass
33
+ @dataclasses.dataclass
33
34
  class TestContainer1:
34
35
  data_1: torch.Tensor
35
36
  data_2: Tuple[torch.Tensor, torch.Tensor]
@@ -40,10 +41,11 @@ torch.export.register_dataclass(
40
41
  )
41
42
 
42
43
 
43
- class TestConvert(unittest.TestCase):
44
+ class TestConvert(googletest.TestCase):
44
45
  """Tests conversion of various modules."""
45
46
 
46
47
  def setUp(self):
48
+ super().setUp()
47
49
  torch.manual_seed(0)
48
50
 
49
51
  def test_convert_add(self):
@@ -66,8 +68,9 @@ class TestConvert(unittest.TestCase):
66
68
  )
67
69
 
68
70
  def test_convert_dot_add(self):
71
+ """Tests conversion of a matrix multiplication followed by an add."""
72
+
69
73
  class DotAdd(torch.nn.Module):
70
- """Tests conversion of a matrix multiplication followed by an add."""
71
74
 
72
75
  def forward(self, a, b, c):
73
76
  return a @ b + c
@@ -97,20 +100,21 @@ class TestConvert(unittest.TestCase):
97
100
  """Tests conversion of a model with more than 10 arguments."""
98
101
 
99
102
  class AddChainWith11Args(torch.nn.Module):
103
+ """A model with 11 arguments."""
100
104
 
101
105
  def forward(
102
106
  self,
103
- arg0: "f32[64]",
104
- arg1: "f32[64]",
105
- arg2: "f32[64]",
106
- arg3: "f32[64]",
107
- arg4: "f32[64]",
108
- arg5: "f32[64]",
109
- arg6: "f32[64]",
110
- arg7: "f32[64]",
111
- arg8: "f32[64]",
112
- arg9: "f32[64]",
113
- arg10: "f32[64]",
107
+ arg0: torch.Tensor,
108
+ arg1: torch.Tensor,
109
+ arg2: torch.Tensor,
110
+ arg3: torch.Tensor,
111
+ arg4: torch.Tensor,
112
+ arg5: torch.Tensor,
113
+ arg6: torch.Tensor,
114
+ arg7: torch.Tensor,
115
+ arg8: torch.Tensor,
116
+ arg9: torch.Tensor,
117
+ arg10: torch.Tensor,
114
118
  ):
115
119
  add0 = torch.add(arg0, arg1)
116
120
  add1 = torch.add(add0, arg2)
@@ -149,6 +153,7 @@ class TestConvert(unittest.TestCase):
149
153
  """Tests conversion of a model that returns multiple outputs."""
150
154
 
151
155
  class BasicAddModelWithMultipleOutputs(torch.nn.Module):
156
+ """A model that returns multiple outputs."""
152
157
 
153
158
  def forward(self, arg0, arg1):
154
159
  add0 = arg0 + arg1
@@ -172,6 +177,7 @@ class TestConvert(unittest.TestCase):
172
177
  """Tests conversion of a model that returns multiple outputs."""
173
178
 
174
179
  class BasicAddModelWithMultipleOutputs(torch.nn.Module):
180
+ """A model that returns multiple outputs."""
175
181
 
176
182
  def forward(self, arg0, arg1):
177
183
  add0 = arg0 + arg1
@@ -215,8 +221,8 @@ class TestConvert(unittest.TestCase):
215
221
  )
216
222
  self.assertTrue(result)
217
223
 
218
- def test_apply_tfl_backdoor_flags(self):
219
- """Tests if _apply_tfl_backdoor_flags correctly sets the values in a Converter object."""
224
+ def test_apply_tfl_converter_flags(self):
225
+ """Tests if _apply_tfl_converter_flags correctly sets the values in a Converter object."""
220
226
 
221
227
  class MockConverterInternalObject:
222
228
 
@@ -231,12 +237,12 @@ class TestConvert(unittest.TestCase):
231
237
 
232
238
  mock_converter = MockConverter()
233
239
  flags = {"key1": "new_value1", "key2": {"subkey2": "new_subvalue2"}}
234
- cutils._apply_tfl_backdoor_flags(mock_converter, flags)
240
+ conversion_utils.apply_tfl_converter_flags(mock_converter, flags)
235
241
 
236
242
  self.assertTrue(flags["key1"], "new_value1")
237
243
  self.assertTrue(flags["key2"]["subkey2"], "new_subvalue2")
238
244
 
239
- def test_convert_add_backdoor_flags(self):
245
+ def test_convert_add_converter_flags(self):
240
246
  """Tests conversion of an add module setting a tflite converter flag."""
241
247
 
242
248
  class Add(torch.nn.Module):
@@ -250,21 +256,23 @@ class TestConvert(unittest.TestCase):
250
256
  )
251
257
  torch_module = Add().eval()
252
258
 
253
- with tempfile.TemporaryDirectory() as tmp_dir_path:
254
- ir_dump_path = os.path.join(
255
- tmp_dir_path, "test_convert_add_backdoor_flags_mlir_dump"
256
- )
257
- ai_edge_torch.convert(
258
- torch_module,
259
- args,
260
- _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path},
261
- )
262
- self.assertTrue(os.path.isdir(ir_dump_path))
259
+ tmp_dir_path = self.create_tempdir()
260
+ ir_dump_path = os.path.join(
261
+ tmp_dir_path, "test_convert_add_converter_flags_mlir_dump"
262
+ )
263
+ ai_edge_torch.convert(
264
+ torch_module,
265
+ args,
266
+ _ai_edge_converter_flags={"ir_dump_dir": ir_dump_path},
267
+ )
268
+ self.assertTrue(os.path.isdir(ir_dump_path))
263
269
 
270
+ @googletest.skipIf(
271
+ not config.Config.use_torch_xla,
272
+ reason="Shape polymorphism is not yet support with odml_torch.",
273
+ )
264
274
  def test_convert_model_with_dynamic_batch(self):
265
- """
266
- Test converting a simple model with dynamic batch size.
267
- """
275
+ """Test converting a simple model with dynamic batch size."""
268
276
 
269
277
  class SampleModel(torch.nn.Module):
270
278
 
@@ -294,9 +302,7 @@ class TestConvert(unittest.TestCase):
294
302
  )
295
303
 
296
304
  def test_convert_model_with_kwargs(self):
297
- """
298
- Test converting a simple model with sample_kwargs.
299
- """
305
+ """Test converting a simple model with sample_kwargs."""
300
306
 
301
307
  class SampleModel(torch.nn.Module):
302
308
 
@@ -315,9 +321,7 @@ class TestConvert(unittest.TestCase):
315
321
  )
316
322
 
317
323
  def test_convert_model_with_args_kwargs(self):
318
- """
319
- Test converting a simple model with both sample_args and sample_kwargs.
320
- """
324
+ """Test converting a simple model with both sample_args and sample_kwargs."""
321
325
 
322
326
  class SampleModel(torch.nn.Module):
323
327
 
@@ -337,9 +341,7 @@ class TestConvert(unittest.TestCase):
337
341
  )
338
342
 
339
343
  def test_convert_model_with_args_nested_kwargs_1(self):
340
- """
341
- Test converting a simple model with both sample_args and nested sample_kwargs.
342
- """
344
+ """Test converting a simple model with both sample_args and nested sample_kwargs."""
343
345
 
344
346
  class SampleModel(torch.nn.Module):
345
347
 
@@ -366,9 +368,7 @@ class TestConvert(unittest.TestCase):
366
368
  )
367
369
 
368
370
  def test_convert_model_with_args_nested_kwargs_2(self):
369
- """
370
- Test converting a simple model with both sample_args and nested sample_kwargs.
371
- """
371
+ """Test converting a simple model with both sample_args and nested sample_kwargs."""
372
372
 
373
373
  class SampleModel(torch.nn.Module):
374
374
 
@@ -395,9 +395,7 @@ class TestConvert(unittest.TestCase):
395
395
  )
396
396
 
397
397
  def test_convert_model_with_args_nested_kwargs_3(self):
398
- """
399
- Test converting a simple model with both sample_args and nested sample_kwargs.
400
- """
398
+ """Test converting a simple model with both sample_args and nested sample_kwargs."""
401
399
 
402
400
  class SampleModel(torch.nn.Module):
403
401
 
@@ -437,4 +435,4 @@ class TestConvert(unittest.TestCase):
437
435
 
438
436
 
439
437
  if __name__ == "__main__":
440
- unittest.main()
438
+ googletest.main()
@@ -12,18 +12,21 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Tests conversion modules that are meant to be wrapped as composites."""
15
16
 
16
-
17
- from typing import Callable
18
- import unittest
17
+ from collections.abc import Callable
19
18
 
20
19
  import ai_edge_torch
21
20
  from ai_edge_torch.testing import model_coverage
22
21
  import parameterized
23
22
  import torch
24
23
 
24
+ from tensorflow.python.platform import googletest
25
+
26
+
27
+ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
28
+ """Wraps a function into a torch module."""
25
29
 
26
- def _func_to_torch_module(func: Callable):
27
30
  class TestModule(torch.nn.Module):
28
31
 
29
32
  def __init__(self, func):
@@ -36,7 +39,7 @@ def _func_to_torch_module(func: Callable):
36
39
  return TestModule(func).eval()
37
40
 
38
41
 
39
- class TestConvertComposites(unittest.TestCase):
42
+ class TestConvertComposites(googletest.TestCase):
40
43
  """Tests conversion modules that are meant to be wrapped as composites."""
41
44
 
42
45
  def test_convert_hardswish(self):
@@ -51,7 +54,8 @@ class TestConvertComposites(unittest.TestCase):
51
54
  )
52
55
 
53
56
  @parameterized.parameterized.expand([
54
- # input_size, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
57
+ # (input_size, kernel_size, stride, padding, ceil_mode,
58
+ # count_include_pad, divisor_override)
55
59
  # no padding, stride = 1
56
60
  ([1, 3, 6, 6], [3, 3], [1, 1], [0, 0], False, True, None),
57
61
  # add stride
@@ -64,6 +68,8 @@ class TestConvertComposites(unittest.TestCase):
64
68
  ([1, 3, 6, 6], [3, 3], [1, 1], [0, 1], False, True, None),
65
69
  # add both stride and padding
66
70
  ([1, 3, 6, 6], [3, 3], [2, 2], [1, 1], False, True, None),
71
+ # padding set to one number
72
+ ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
67
73
  # count_include_pad = False
68
74
  ([1, 3, 6, 6], [3, 3], [1, 1], [1, 1], False, False, None),
69
75
  # ceil_mode = True
@@ -72,8 +78,6 @@ class TestConvertComposites(unittest.TestCase):
72
78
  ([1, 3, 6, 6], [3, 3], [3, 3], [1, 1], True, True, None),
73
79
  # set divisor_override
74
80
  ([1, 3, 6, 6], [3, 3], [1, 1], 0, False, True, 6),
75
- # padding set to one number
76
- ([1, 3, 6, 6], [3, 3], [1, 1], 1, False, True, None),
77
81
  ])
78
82
  def test_convert_avg_pool2d(self, input_size, *args):
79
83
  """Tests conversion of a module containing an avg_pool2d aten."""
@@ -111,7 +115,7 @@ class TestConvertComposites(unittest.TestCase):
111
115
  def test_convert_upsample_bilinear_functional(self, input_size, kwargs):
112
116
  """Tests conversion of a torch.nn.functional.upsample module."""
113
117
  torch_module = _func_to_torch_module(
114
- lambda input_tensor: torch.nn.functional.upsample(
118
+ lambda input_tensor: torch.nn.functional.upsample( # pylint: disable=unnecessary-lambda
115
119
  input_tensor, **kwargs
116
120
  )
117
121
  )
@@ -146,7 +150,7 @@ class TestConvertComposites(unittest.TestCase):
146
150
  def test_convert_upsample_bilinear(self, input_size, kwargs):
147
151
  """Tests conversion of a torch.nn.Upsample module."""
148
152
  torch_module = _func_to_torch_module(
149
- lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor)
153
+ lambda input_tensor: torch.nn.Upsample(**kwargs)(input_tensor) # pylint: disable=unnecessary-lambda
150
154
  )
151
155
  tracing_args = (torch.randn(*input_size),)
152
156
  edge_model = ai_edge_torch.convert(torch_module, tracing_args)
@@ -179,7 +183,7 @@ class TestConvertComposites(unittest.TestCase):
179
183
  def test_convert_interpolate_bilinear_functional(self, input_size, kwargs):
180
184
  """Tests conversion of a torch.nn.functional.interpolate module."""
181
185
  torch_module = _func_to_torch_module(
182
- lambda input_tensor: torch.nn.functional.interpolate(
186
+ lambda input_tensor: torch.nn.functional.interpolate( # pylint: disable=unnecessary-lambda
183
187
  input_tensor, **kwargs
184
188
  )
185
189
  )
@@ -227,4 +231,4 @@ class TestConvertComposites(unittest.TestCase):
227
231
 
228
232
 
229
233
  if __name__ == '__main__':
230
- unittest.main()
234
+ googletest.main()
@@ -12,19 +12,21 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
- import unittest
15
+ """Tests for multi-signature conversion."""
17
16
 
18
17
  import ai_edge_torch
19
18
  from ai_edge_torch.testing import model_coverage
20
19
  import torch
21
20
  import torchvision
22
21
 
22
+ from tensorflow.python.platform import googletest
23
+
23
24
 
24
- class TestConvertMultiSignature(unittest.TestCase):
25
+ class TestConvertMultiSignature(googletest.TestCase):
25
26
  """Tests conversion of various modules through multi-signature conversion."""
26
27
 
27
28
  def setUp(self):
29
+ super().setUp()
28
30
  torch.manual_seed(0)
29
31
 
30
32
  def test_convert_mobilenet_v2_with_default(self):
@@ -144,4 +146,4 @@ class TestConvertMultiSignature(unittest.TestCase):
144
146
 
145
147
 
146
148
  if __name__ == "__main__":
147
- unittest.main()
149
+ googletest.main()
@@ -12,12 +12,13 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
- import unittest
15
+ """Tests for to_channel_last_io API and module wrapper."""
17
16
 
18
17
  import ai_edge_torch
19
18
  import torch
20
19
 
20
+ from tensorflow.python.platform import googletest
21
+
21
22
 
22
23
  class Identity(torch.nn.Module):
23
24
 
@@ -25,7 +26,7 @@ class Identity(torch.nn.Module):
25
26
  return x
26
27
 
27
28
 
28
- class TestToChannelLastIO(unittest.TestCase):
29
+ class TestToChannelLastIO(googletest.TestCase):
29
30
  """Tests to_channel_last_io API and module wrapper."""
30
31
 
31
32
  def test_no_transformations(self):
@@ -92,4 +93,4 @@ class TestToChannelLastIO(unittest.TestCase):
92
93
 
93
94
 
94
95
  if __name__ == "__main__":
95
- unittest.main()
96
+ googletest.main()
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Transforms the input and output of a module to channel last layout."""
15
16
 
16
17
  from typing import Optional
17
18
 
@@ -82,8 +83,10 @@ def to_channel_last_io(
82
83
  (N, C, ...) to channel last (N, ..., C).
83
84
  outputs (list[int]): Transform outputs with indices in the list from channel
84
85
  first (N, C, ...) to channel last (N, ..., C).
86
+
85
87
  Returns:
86
- The wrapped nn.Module with additional layout transposes after inputs and/or before
88
+ The wrapped nn.Module with additional layout transposes after inputs and/or
89
+ before
87
90
  outputs.
88
91
  """
89
92
  return ChannelLastIOWrapper(module, args=args, outputs=outputs)
@@ -0,0 +1,24 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Provides a configuration for the AI Edge Torch library."""
17
+
18
+ import dataclasses
19
+ import os
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Config:
24
+ use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "True") == "True"
@@ -0,0 +1,20 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ from absl import flags
17
+
18
+
19
+ def pytest_configure(config):
20
+ flags.FLAGS.mark_as_parsed()
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Culprit finder for AI Edge Torch conversion."""
15
16
 
16
17
  import contextlib
17
18
  import copy
@@ -20,14 +21,13 @@ import functools
20
21
  import io
21
22
  import operator
22
23
  import os
23
- import sys
24
24
  from typing import Any, Callable, Generator, List, Optional, Tuple, Union
25
25
 
26
26
  import ai_edge_torch
27
27
  from ai_edge_torch.debug import utils
28
- from functorch.compile import minifier as fx_minifier
29
28
  import torch
30
29
  from torch._functorch import aot_autograd
30
+ from torch._functorch.fx_minifier import minifier as fx_minifier
31
31
  import torch.utils._pytree as pytree
32
32
 
33
33
  _torch_float_dtypes = {
@@ -116,7 +116,7 @@ class Culprit(SearchResult):
116
116
  print_output: bool - If true, prints the code to stdout. Otherwise returns
117
117
  the code in a str.
118
118
  """
119
- # TODO (b/321263453): Support Python code gen with sample arg tensor values.
119
+ # TODO: b/321263453 - Support Python code gen with sample arg tensor values.
120
120
  random_inputs = True
121
121
 
122
122
  graph_module_code = self.graph_module.print_readable(
@@ -152,6 +152,7 @@ class Culprit(SearchResult):
152
152
 
153
153
  def print_code(self, print_output=True):
154
154
  """Print the Python code for culprit graph module, sample args, and AI
155
+
155
156
  Edge Torch conversion that will fail with the error.
156
157
 
157
158
  Args:
@@ -188,8 +189,8 @@ class Culprit(SearchResult):
188
189
 
189
190
 
190
191
  def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
191
- """
192
- This function turns all operator getitem nodes in ExportedProgram FX graph to
192
+ """This function turns all operator getitem nodes in ExportedProgram FX graph to
193
+
193
194
  new nodes composed of "computation + getitem". The normalization duplicates
194
195
  some computations in the graph but would make the graph more friendly for
195
196
  partitioning in FX minifier.
@@ -367,19 +368,18 @@ def _search_model(
367
368
  max_granularity: Optional[int] = None,
368
369
  enable_fx_minifier_logging: bool = False,
369
370
  ) -> Generator[SearchResult, None, None]:
370
- """Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
371
+ """Finds subgraphs in the torch model that satisfy a certain predicate function provided by the users.
371
372
 
372
373
  Args:
373
- predicate_f: a predicate function the users specify.
374
- It takes a FX (sub)graph and the inputs to this graph,
375
- return True if the graph satisfies the predicate,
376
- return False otherwise.
374
+ predicate_f: a predicate function the users specify. It takes a FX
375
+ (sub)graph and the inputs to this graph, return True if the graph
376
+ satisfies the predicate, return False otherwise.
377
377
  model: model in which to search subgraph.
378
- export_args: A set of args to trace the model with,
379
- i.e. model(*args) must run.
380
- max_granularity - FX minifier arg. The maximum granularity (number of nodes)
381
- in the returned ATen FX subgraph of the culprit.
382
- enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
378
+ export_args: A set of args to trace the model with, i.e. model(*args) must
379
+ run.
380
+ max_granularity: FX minifier arg. The maximum granularity (number of nodes) in the returned ATen FX subgraph of the culprit.
381
+ enable_fx_minifier_logging: If true, allows the underlying FX minifier to
382
+ log the progress.
383
383
  """
384
384
 
385
385
  if isinstance(model, torch.nn.Module):
@@ -469,13 +469,13 @@ def find_culprits(
469
469
 
470
470
  Args:
471
471
  torch_model: model to export and save
472
- args: A set of args to trace the model with, i.e.
473
- torch_model(*args) must run
474
- max_granularity - FX minifier arg. The maximum granularity (number of nodes)
475
- in the returned ATen FX subgraph of the culprit.
476
- runtime_errors: If true, find culprits for Python runtime errors
477
- with converted model.
478
- enable_fx_minifier_logging: If true, allows the underlying FX minifier to log the progress.
472
+ args: A set of args to trace the model with, i.e. torch_model(*args) must
473
+ run.
474
+ max_granularity: FX minifier arg. The maximum granularity (number of nodes) in the returned ATen FX subgraph of the culprit.
475
+ runtime_errors: If true, find culprits for Python runtime errors with
476
+ converted model.
477
+ enable_fx_minifier_logging: If true, allows the underlying FX minifier to
478
+ log the progress.
479
479
  """
480
480
 
481
481
  fx_minifier_checker = functools.partial(
@@ -17,11 +17,12 @@
17
17
  import ast
18
18
  import io
19
19
  import sys
20
- import unittest
21
20
 
22
21
  from ai_edge_torch.debug import find_culprits
23
22
  import torch
24
23
 
24
+ from tensorflow.python.platform import googletest
25
+
25
26
  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")
26
27
 
27
28
  _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -49,7 +50,7 @@ class BadModel(torch.nn.Module):
49
50
  return x
50
51
 
51
52
 
52
- class TestCulprit(unittest.TestCase):
53
+ class TestCulprit(googletest.TestCase):
53
54
 
54
55
  def test_find_culprits(self):
55
56
  model = BadModel().eval()
@@ -131,4 +132,4 @@ class TestCulprit(unittest.TestCase):
131
132
 
132
133
 
133
134
  if __name__ == "__main__":
134
- unittest.main()
135
+ googletest.main()
@@ -12,15 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
16
-
17
- import unittest
15
+ """Tests for search_model."""
18
16
 
19
17
  from ai_edge_torch.debug import _search_model
20
18
  import torch
21
19
 
20
+ from tensorflow.python.platform import googletest
21
+
22
22
 
23
- class TestSearchModel(unittest.TestCase):
23
+ class TestSearchModel(googletest.TestCase):
24
24
 
25
25
  def test_search_model_with_ops(self):
26
26
  class MultipleOpsModel(torch.nn.Module):
@@ -48,4 +48,4 @@ class TestSearchModel(unittest.TestCase):
48
48
 
49
49
 
50
50
  if __name__ == "__main__":
51
- unittest.main()
51
+ googletest.main()
@@ -12,12 +12,12 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
+ """Utils for debugging."""
16
+
15
17
  import contextlib
16
18
  import sys
17
19
 
18
20
  import torch
19
- from torch.export.graph_signature import InputKind
20
- import torch.fx._pytree as fx_pytree
21
21
  from torch.utils import _pytree as pytree
22
22
 
23
23
 
@@ -33,6 +33,15 @@ def exported_program_to_fx_graph_module_and_inputs(
33
33
 
34
34
  @contextlib.contextmanager
35
35
  def redirect_stdio(stdout, stderr):
36
+ """Redirects stdout and stderr to the given file objects.
37
+
38
+ Args:
39
+ stdout: A file object to redirect stdout to.
40
+ stderr: A file object to redirect stderr to.
41
+
42
+ Yields:
43
+ The file objects that stdout and stderr were redirected to.
44
+ """
36
45
  old_stdout = sys.stdout
37
46
  old_stderr = sys.stderr
38
47
 
@@ -34,8 +34,8 @@ def convert_gemma_to_tflite(
34
34
  quantize: bool = True,
35
35
  ):
36
36
  """An example method for converting a Gemma 2B model to multi-signature
37
- tflite model.
38
37
 
38
+ tflite model.
39
39
  Args:
40
40
  checkpoint_path (str): The filepath to the model checkpoint, or directory
41
41
  holding the checkpoint.
@@ -43,8 +43,8 @@ def convert_gemma_to_tflite(
43
43
  Defaults to 512.
44
44
  kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
45
45
  including both prefill and decode. Defaults to 1024.
46
- quantize (bool, optional): Whether the model should be quanized.
47
- Defaults to True.
46
+ quantize (bool, optional): Whether the model should be quantized. Defaults
47
+ to True.
48
48
  """
49
49
  pytorch_model = gemma.build_2b_model(
50
50
  checkpoint_path, kv_cache_max_len=kv_cache_max_len
@@ -73,7 +73,9 @@ class Gemma(nn.Module):
73
73
  )
74
74
  self.rope_cache = attn_utils.build_rope_cache(
75
75
  size=config.kv_cache_max,
76
- dim=int(config.attn_config.rotary_percentage * config.head_dim),
76
+ dim=int(
77
+ config.attn_config.rotary_percentage * config.attn_config.head_dim
78
+ ),
77
79
  base=10_000,
78
80
  condense_ratio=1,
79
81
  dtype=torch.float32,
@@ -125,6 +127,7 @@ class Gemma(nn.Module):
125
127
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
126
128
  attn_config = cfg.AttentionConfig(
127
129
  num_heads=8,
130
+ head_dim=256,
128
131
  num_query_groups=1,
129
132
  rotary_percentage=1.0,
130
133
  )