ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.3.0.dev20240809__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

Files changed (104)
  1. ai_edge_torch/__init__.py +5 -5
  2. ai_edge_torch/{convert → _convert}/conversion.py +40 -50
  3. ai_edge_torch/_convert/conversion_utils.py +64 -0
  4. ai_edge_torch/{convert → _convert}/converter.py +83 -43
  5. ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
  6. ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
  7. ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
  8. ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
  9. ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
  10. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
  11. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
  12. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
  13. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
  14. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
  15. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
  16. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
  17. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
  18. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
  19. ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
  20. ai_edge_torch/_convert/signature.py +100 -0
  21. ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
  22. ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
  23. ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
  24. ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
  25. ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
  26. ai_edge_torch/config.py +24 -0
  27. ai_edge_torch/conftest.py +20 -0
  28. ai_edge_torch/debug/culprit.py +22 -22
  29. ai_edge_torch/debug/test/test_culprit.py +4 -3
  30. ai_edge_torch/debug/test/test_search_model.py +5 -5
  31. ai_edge_torch/debug/utils.py +11 -2
  32. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
  33. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
  34. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
  35. ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
  36. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
  38. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
  39. ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
  40. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
  41. ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
  42. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  43. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
  44. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
  45. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
  46. ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
  47. ai_edge_torch/generative/examples/t5/t5.py +2 -2
  48. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
  49. ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
  50. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
  52. ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
  55. ai_edge_torch/generative/fx_passes/__init__.py +2 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
  57. ai_edge_torch/generative/layers/attention.py +35 -26
  58. ai_edge_torch/generative/layers/attention_utils.py +23 -12
  59. ai_edge_torch/generative/layers/builder.py +0 -1
  60. ai_edge_torch/generative/layers/feed_forward.py +6 -10
  61. ai_edge_torch/generative/layers/kv_cache.py +0 -1
  62. ai_edge_torch/generative/layers/model_config.py +2 -5
  63. ai_edge_torch/generative/layers/normalization.py +5 -7
  64. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
  65. ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
  66. ai_edge_torch/generative/layers/unet/model_config.py +14 -15
  67. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
  68. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
  69. ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
  70. ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
  72. ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
  73. ai_edge_torch/generative/test/test_model_conversion.py +24 -25
  74. ai_edge_torch/generative/test/test_quantize.py +10 -5
  75. ai_edge_torch/generative/utilities/loader.py +12 -12
  76. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
  77. ai_edge_torch/generative/utilities/t5_loader.py +12 -13
  78. ai_edge_torch/hlfb/__init__.py +1 -1
  79. ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
  80. ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
  81. ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
  82. ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
  83. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
  84. ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
  85. ai_edge_torch/lowertools/_shim.py +80 -0
  86. ai_edge_torch/lowertools/common_utils.py +89 -0
  87. ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
  88. ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
  89. ai_edge_torch/model.py +14 -9
  90. ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
  91. ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
  92. ai_edge_torch/quantize/quant_config.py +7 -7
  93. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
  94. ai_edge_torch/version.py +1 -1
  95. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/METADATA +1 -1
  96. ai_edge_torch_nightly-0.3.0.dev20240809.dist-info/RECORD +141 -0
  97. ai_edge_torch/convert/conversion_utils.py +0 -439
  98. ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
  99. /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
  100. /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
  101. /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
  102. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/LICENSE +0 -0
  103. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/WHEEL +0 -0
  104. {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/top_level.txt +0 -0
@@ -14,8 +14,7 @@
14
14
  # ==============================================================================
15
15
 
16
16
  # UNet configuration class.
17
- from dataclasses import dataclass
18
- from dataclasses import field
17
+ import dataclasses
19
18
  import enum
20
19
  from typing import List, Optional
21
20
 
@@ -30,13 +29,13 @@ class SamplingType(enum.Enum):
30
29
  CONVOLUTION = enum.auto()
31
30
 
32
31
 
33
- @dataclass
32
+ @dataclasses.dataclass
34
33
  class UpSamplingConfig:
35
34
  mode: SamplingType
36
35
  scale_factor: float
37
36
 
38
37
 
39
- @dataclass
38
+ @dataclasses.dataclass
40
39
  class DownSamplingConfig:
41
40
  mode: SamplingType
42
41
  in_channels: int
@@ -46,7 +45,7 @@ class DownSamplingConfig:
46
45
  out_channels: Optional[int] = None
47
46
 
48
47
 
49
- @dataclass
48
+ @dataclasses.dataclass
50
49
  class ResidualBlock2DConfig:
51
50
  in_channels: int
52
51
  out_channels: int
@@ -56,7 +55,7 @@ class ResidualBlock2DConfig:
56
55
  time_embedding_channels: Optional[int] = None
57
56
 
58
57
 
59
- @dataclass
58
+ @dataclasses.dataclass
60
59
  class AttentionBlock2DConfig:
61
60
  dim: int
62
61
  normalization_config: layers_cfg.NormalizationConfig
@@ -65,7 +64,7 @@ class AttentionBlock2DConfig:
65
64
  attention_batch_size: int = 1
66
65
 
67
66
 
68
- @dataclass
67
+ @dataclasses.dataclass
69
68
  class CrossAttentionBlock2DConfig:
70
69
  query_dim: int
71
70
  cross_dim: int
@@ -75,7 +74,7 @@ class CrossAttentionBlock2DConfig:
75
74
  attention_batch_size: int = 1
76
75
 
77
76
 
78
- @dataclass
77
+ @dataclasses.dataclass
79
78
  class FeedForwardBlock2DConfig:
80
79
  dim: int
81
80
  hidden_dim: int
@@ -84,7 +83,7 @@ class FeedForwardBlock2DConfig:
84
83
  use_bias: bool
85
84
 
86
85
 
87
- @dataclass
86
+ @dataclasses.dataclass
88
87
  class TransformerBlock2DConfig:
89
88
  pre_conv_normalization_config: layers_cfg.NormalizationConfig
90
89
  attention_block_config: AttentionBlock2DConfig
@@ -92,7 +91,7 @@ class TransformerBlock2DConfig:
92
91
  feed_forward_block_config: FeedForwardBlock2DConfig
93
92
 
94
93
 
95
- @dataclass
94
+ @dataclasses.dataclass
96
95
  class UpDecoderBlock2DConfig:
97
96
  in_channels: int
98
97
  out_channels: int
@@ -113,7 +112,7 @@ class UpDecoderBlock2DConfig:
113
112
  context_dim: Optional[int] = None
114
113
 
115
114
 
116
- @dataclass
115
+ @dataclasses.dataclass
117
116
  class SkipUpDecoderBlock2DConfig:
118
117
  in_channels: int
119
118
  out_channels: int
@@ -136,7 +135,7 @@ class SkipUpDecoderBlock2DConfig:
136
135
  context_dim: Optional[int] = None
137
136
 
138
137
 
139
- @dataclass
138
+ @dataclasses.dataclass
140
139
  class DownEncoderBlock2DConfig:
141
140
  in_channels: int
142
141
  out_channels: int
@@ -157,7 +156,7 @@ class DownEncoderBlock2DConfig:
157
156
  context_dim: Optional[int] = None
158
157
 
159
158
 
160
- @dataclass
159
+ @dataclasses.dataclass
161
160
  class MidBlock2DConfig:
162
161
  in_channels: int
163
162
  normalization_config: layers_cfg.NormalizationConfig
@@ -173,7 +172,7 @@ class MidBlock2DConfig:
173
172
  context_dim: Optional[int] = None
174
173
 
175
174
 
176
- @dataclass
175
+ @dataclasses.dataclass
177
176
  class AutoEncoderConfig:
178
177
  """Configurations of encoder/decoder in the autoencoder model."""
179
178
 
@@ -210,7 +209,7 @@ class AutoEncoderConfig:
210
209
  mid_block_config: MidBlock2DConfig
211
210
 
212
211
 
213
- @dataclass
212
+ @dataclasses.dataclass
214
213
  class DiffusionModelConfig:
215
214
  """Configurations of Diffusion model."""
216
215
 
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -13,8 +13,6 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import json
17
-
18
16
  from ai_edge_quantizer import quantizer
19
17
  from ai_edge_torch.generative.quantize import quant_attrs
20
18
  from ai_edge_torch.generative.quantize import quant_recipe
@@ -25,7 +25,8 @@ class LayerQuantRecipe:
25
25
  """Quantization recipe for a single Edge Generative API layer (e.g. Attention).
26
26
 
27
27
  Generic layer-scoped quantization recipe that specifies how this layer should
28
- be quantized by the Edge Generative API. This is applicable to layers implemented
28
+ be quantized by the Edge Generative API. This is applicable to layers
29
+ implemented
29
30
  in ai_edge_torch/generative/layers/. Combinations of attributes that are not
30
31
  supported during runtime will be detected when .verify() is called.
31
32
 
@@ -83,7 +84,8 @@ class LayerQuantRecipe:
83
84
  class GenerativeQuantRecipe:
84
85
  """Quantization recipe for a model composed of the Edge Generative API layers.
85
86
 
86
- Some layers can be specified with different `LayerQuantRecipe` for each block by
87
+ Some layers can be specified with different `LayerQuantRecipe` for each block
88
+ by
87
89
  providing a dictionary keyed by the TransformerBlock index, e.g. attention
88
90
  and feedforward. For example,
89
91
 
@@ -102,11 +104,11 @@ class GenerativeQuantRecipe:
102
104
  default: The quantization recipe for global scope of the model.
103
105
  embedding: Recipe for the embedding table.
104
106
  attention: Recipe for the attention blocks. This could be specified with
105
- different LayerQuantRecipe for each block by providing a dictionary
106
- keyed by the TransformerBlock index.
107
+ different LayerQuantRecipe for each block by providing a dictionary keyed
108
+ by the TransformerBlock index.
107
109
  feedforward: Recipe for the feedforward layers. This could be specified with
108
- different LayerQuantRecipe for each block by providing a dictionary
109
- keyed by the TransformerBlock index.
110
+ different LayerQuantRecipe for each block by providing a dictionary keyed
111
+ by the TransformerBlock index.
110
112
  """
111
113
 
112
114
  default: Optional[LayerQuantRecipe] = None
@@ -16,7 +16,8 @@
16
16
  """Helper functions to construct custom quantization recipes.
17
17
 
18
18
  These are intended for more advanced users who want to configure their own
19
- quantization recipes. For pre-constructed recipes, use `quant_recipes.py` instead.
19
+ quantization recipes. For pre-constructed recipes, use `quant_recipes.py`
20
+ instead.
20
21
 
21
22
  Typical usage example:
22
23
 
@@ -14,24 +14,23 @@
14
14
  # ==============================================================================
15
15
  # A suite of tests to validate experimental external KV Cache layers and models.
16
16
 
17
- import unittest
18
-
19
17
  from ai_edge_torch.generative.examples.experimental.gemma import gemma
20
18
  from ai_edge_torch.generative.examples.experimental.phi import phi2
21
19
  from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama # NOQA
22
20
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
23
21
  import ai_edge_torch.generative.layers.model_config as cfg
24
- import numpy as np
25
22
  import torch
26
23
 
24
+ from tensorflow.python.platform import googletest
25
+
27
26
 
28
- class TestExternalKVLayers(unittest.TestCase):
27
+ class TestExternalKVLayers(googletest.TestCase):
29
28
 
30
29
  def _get_test_config(
31
30
  self, num_layers, head_dim, num_query_groups, kv_cache_max_len
32
31
  ):
33
32
  attn_config = cfg.AttentionConfig(
34
- num_heads=1, num_query_groups=num_query_groups
33
+ num_heads=1, head_dim=head_dim, num_query_groups=num_query_groups
35
34
  )
36
35
  config = cfg.ModelConfig(
37
36
  kv_cache_max_len=kv_cache_max_len,
@@ -117,7 +116,7 @@ class TestExternalKVLayers(unittest.TestCase):
117
116
  self.assertEqual(input_specs[1].arg.name, "kv_v_0")
118
117
 
119
118
 
120
- class TestExternalKVModels(unittest.TestCase):
119
+ class TestExternalKVModels(googletest.TestCase):
121
120
 
122
121
  def test_can_build_gemma(self):
123
122
  gemma.define_and_run_2b(checkpoint_path=None, test_model=True)
@@ -130,4 +129,4 @@ class TestExternalKVModels(unittest.TestCase):
130
129
 
131
130
 
132
131
  if __name__ == "__main__":
133
- unittest.main()
132
+ googletest.main()
@@ -16,15 +16,16 @@
16
16
 
17
17
  import os
18
18
  import tempfile
19
- import unittest
20
19
 
21
20
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
22
21
  from ai_edge_torch.generative.utilities import loader as loading_utils
23
22
  import safetensors.torch
24
23
  import torch
25
24
 
25
+ from tensorflow.python.platform import googletest
26
26
 
27
- class TestLoader(unittest.TestCase):
27
+
28
+ class TestLoader(googletest.TestCase):
28
29
  """Unit tests that check weight loader."""
29
30
 
30
31
  def test_load_safetensors(self):
@@ -78,4 +79,4 @@ class TestLoader(unittest.TestCase):
78
79
 
79
80
 
80
81
  if __name__ == "__main__":
81
- unittest.main()
82
+ googletest.main()
@@ -14,9 +14,6 @@
14
14
  # ==============================================================================
15
15
  # Testing model conversion for a few gen-ai models.
16
16
  import copy
17
- import os
18
- import tempfile
19
- import unittest
20
17
 
21
18
  import ai_edge_torch
22
19
  from ai_edge_torch.generative.examples.gemma import gemma
@@ -27,22 +24,24 @@ from ai_edge_torch.testing import model_coverage
27
24
  import numpy as np
28
25
  import torch
29
26
 
27
+ from tensorflow.python.platform import googletest
30
28
 
31
- class TestModelConversion(unittest.TestCase):
29
+
30
+ class TestModelConversion(googletest.TestCase):
32
31
  """Unit tests that check for model conversion and correctness."""
33
32
 
34
33
  def test_toy_model_with_kv_cache(self):
35
34
  config = toy_model_with_kv_cache.get_model_config()
36
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
35
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
37
36
  idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
38
37
  [10], dtype=torch.int64
39
38
  )
40
39
 
41
40
  edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
42
41
 
43
- # TODO(b/338288901): re-enable test to check output tensors.
42
+ # TODO: b/338288901 - re-enable test to check output tensors.
44
43
  skip_output_check = True
45
- if skip_output_check is False:
44
+ if not skip_output_check:
46
45
  self.assertTrue(
47
46
  model_coverage.compare_tflite_torch(
48
47
  edge_model,
@@ -57,16 +56,16 @@ class TestModelConversion(unittest.TestCase):
57
56
  def test_toy_model_with_multi_batches(self):
58
57
  config = toy_model_with_kv_cache.get_model_config()
59
58
  config.batch_size = 2
60
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
59
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
61
60
  idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
62
61
  [10], dtype=torch.int64
63
62
  )
64
63
 
65
64
  edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
66
65
 
67
- # TODO(b/338288901): re-enable test to check output tensors.
66
+ # TODO: b/338288901 - re-enable test to check output tensors.
68
67
  skip_output_check = True
69
- if skip_output_check is False:
68
+ if not skip_output_check:
70
69
  self.assertTrue(
71
70
  model_coverage.compare_tflite_torch(
72
71
  edge_model,
@@ -81,16 +80,16 @@ class TestModelConversion(unittest.TestCase):
81
80
  def test_toy_model_with_kv_cache_with_hlfb(self):
82
81
  config = toy_model_with_kv_cache.get_model_config()
83
82
  config.enable_hlfb = True
84
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
83
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
85
84
  idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
86
85
  [10], dtype=torch.int64
87
86
  )
88
87
 
89
88
  edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
90
89
 
91
- # TODO(b/338288901): re-enable test to check output tensors.
90
+ # TODO: b/338288901 - re-enable test to check output tensors.
92
91
  skip_output_check = True
93
- if skip_output_check is False:
92
+ if not skip_output_check:
94
93
  self.assertTrue(
95
94
  model_coverage.compare_tflite_torch(
96
95
  edge_model,
@@ -105,7 +104,7 @@ class TestModelConversion(unittest.TestCase):
105
104
  def test_tiny_llama(self):
106
105
  self.skipTest("b/338288901")
107
106
  config = tiny_llama.get_fake_model_config_for_test()
108
- pytorch_model = tiny_llama.TinyLLamma(config)
107
+ pytorch_model = tiny_llama.TinyLLamma(config).eval()
109
108
 
110
109
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
111
110
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
@@ -114,9 +113,9 @@ class TestModelConversion(unittest.TestCase):
114
113
 
115
114
  edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
116
115
 
117
- # TODO(b/338288901): re-enable test to check output tensors.
116
+ # TODO: b/338288901 - re-enable test to check output tensors.
118
117
  skip_output_check = True
119
- if skip_output_check is False:
118
+ if not skip_output_check:
120
119
  self.assertTrue(
121
120
  model_coverage.compare_tflite_torch(
122
121
  edge_model,
@@ -130,7 +129,7 @@ class TestModelConversion(unittest.TestCase):
130
129
 
131
130
  def test_tiny_llama_multisig(self):
132
131
  config = tiny_llama.get_fake_model_config_for_test()
133
- pytorch_model = tiny_llama.TinyLLamma(config)
132
+ pytorch_model = tiny_llama.TinyLLamma(config).eval()
134
133
 
135
134
  # prefill
136
135
  seq_len = 10
@@ -151,9 +150,9 @@ class TestModelConversion(unittest.TestCase):
151
150
  .convert()
152
151
  )
153
152
 
154
- # TODO(b/338288901): re-enable test to check output tensors.
153
+ # TODO: b/338288901 - re-enable test to check output tensors.
155
154
  skip_output_check = True
156
- if skip_output_check is False:
155
+ if not skip_output_check:
157
156
  copied_model = copy.deepcopy(pytorch_model)
158
157
 
159
158
  self.assertTrue(
@@ -188,9 +187,9 @@ class TestModelConversion(unittest.TestCase):
188
187
 
189
188
  edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
190
189
 
191
- # TODO(b/338288901): re-enable test to check output tensors.
190
+ # TODO: b/338288901 - re-enable test to check output tensors.
192
191
  skip_output_check = True
193
- if skip_output_check is False:
192
+ if not skip_output_check:
194
193
  # TODO(talumbau, haoliang): debug numerical diff.
195
194
  self.assertTrue(
196
195
  model_coverage.compare_tflite_torch(
@@ -206,7 +205,7 @@ class TestModelConversion(unittest.TestCase):
206
205
  def test_phi2(self):
207
206
  self.skipTest("b/338288901")
208
207
  config = phi2.get_fake_model_config_for_test()
209
- pytorch_model = phi2.Phi2(config)
208
+ pytorch_model = phi2.Phi2(config).eval()
210
209
 
211
210
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
212
211
  tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
@@ -215,9 +214,9 @@ class TestModelConversion(unittest.TestCase):
215
214
 
216
215
  edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
217
216
 
218
- # TODO(b/338288901): re-enable test to check output tensors.
217
+ # TODO: b/338288901 - re-enable test to check output tensors.
219
218
  skip_output_check = True
220
- if skip_output_check is False:
219
+ if not skip_output_check:
221
220
  self.assertTrue(
222
221
  model_coverage.compare_tflite_torch(
223
222
  edge_model,
@@ -231,4 +230,4 @@ class TestModelConversion(unittest.TestCase):
231
230
 
232
231
 
233
232
  if __name__ == "__main__":
234
- unittest.main()
233
+ googletest.main()
@@ -13,9 +13,8 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import unittest
17
-
18
16
  import ai_edge_torch
17
+ from ai_edge_torch import config
19
18
  from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
20
19
  from ai_edge_torch.generative.quantize import quant_recipe
21
20
  from ai_edge_torch.generative.quantize import quant_recipe_utils
@@ -29,8 +28,10 @@ from ai_edge_torch.testing import model_coverage
29
28
  from parameterized import parameterized
30
29
  import torch
31
30
 
31
+ from tensorflow.python.platform import googletest
32
+
32
33
 
33
- class TestVerifyRecipes(unittest.TestCase):
34
+ class TestVerifyRecipes(googletest.TestCase):
34
35
  """Unit tests that check for model quantization recipes."""
35
36
 
36
37
  @parameterized.expand([
@@ -87,7 +88,7 @@ class TestVerifyRecipes(unittest.TestCase):
87
88
  ).verify()
88
89
 
89
90
 
90
- class TestQuantizeConvert(unittest.TestCase):
91
+ class TestQuantizeConvert(googletest.TestCase):
91
92
  """Test conversion with quantization."""
92
93
 
93
94
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -111,6 +112,10 @@ class TestQuantizeConvert(unittest.TestCase):
111
112
  (_attention_int8_dynamic_recipe()),
112
113
  (_feedforward_int8_dynamic_recipe()),
113
114
  ])
115
+ @googletest.skipIf(
116
+ not config.Config.use_torch_xla,
117
+ reason="Not working with odml_torch at the moment.",
118
+ )
114
119
  def test_quantize_convert_toy_sizes(self, quant_config):
115
120
  config = toy_model.get_model_config()
116
121
  pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -157,4 +162,4 @@ class TestQuantizeConvert(unittest.TestCase):
157
162
 
158
163
 
159
164
  if __name__ == "__main__":
160
- unittest.main()
165
+ googletest.main()
@@ -92,9 +92,7 @@ def load_pytorch_statedict(full_path: str):
92
92
 
93
93
 
94
94
  class ModelLoader:
95
- """A utility class for loading and converting model checkpoints to the
96
- Edge Generative API layer format.
97
- """
95
+ """Utility for loading model checkpoints to the Edge Generative API layer."""
98
96
 
99
97
  @dataclass
100
98
  class TensorNames:
@@ -116,12 +114,13 @@ class ModelLoader:
116
114
  lm_head: str = None
117
115
 
118
116
  def __init__(self, file_name: str, names: TensorNames) -> None:
119
- """ModelLoader constructor. Can be used to load multiple models of the same
120
- type.
117
+ """ModelLoader constructor.
118
+
119
+ Can be used to load multiple models of the same type.
121
120
 
122
121
  Args:
123
- file_name (str): Path to the checkpoint. Can be a directory or an
124
- exact file.
122
+ file_name (str): Path to the checkpoint. Can be a directory or an exact
123
+ file.
125
124
  names (TensorNames): An instance of `TensorNames` to determine mappings.
126
125
  """
127
126
  self._file_name = file_name
@@ -140,7 +139,8 @@ class ModelLoader:
140
139
 
141
140
  Returns:
142
141
  missing_keys (List[str]): a list of str containing the missing keys.
143
- unexpected_keys (List[str]): a list of str containing the unexpected keys.
142
+ unexpected_keys (List[str]): a list of str containing the unexpected
143
+ keys.
144
144
 
145
145
  Raises:
146
146
  ValueError: If conversion results in unmapped tensors and strict mode is
@@ -208,7 +208,7 @@ class ModelLoader:
208
208
  if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
209
209
  return load_pytorch_statedict
210
210
 
211
- raise ValueError(f"File format not supported.")
211
+ raise ValueError("File format not supported.")
212
212
 
213
213
  def _map_feedforward(
214
214
  self,
@@ -346,9 +346,9 @@ class ModelLoader:
346
346
  q_per_kv = (
347
347
  config.attn_config.num_heads // config.attn_config.num_query_groups
348
348
  )
349
- qs = torch.split(q, config.head_dim * q_per_kv)
350
- ks = torch.split(k, config.head_dim)
351
- vs = torch.split(v, config.head_dim)
349
+ qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
350
+ ks = torch.split(k, config.attn_config.head_dim)
351
+ vs = torch.split(v, config.attn_config.head_dim)
352
352
  cycled = [t for group in zip(qs, ks, vs) for t in group]
353
353
  return torch.cat(cycled)
354
354
  else: