ai-edge-torch-nightly 0.3.0.dev20240904__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
     return self.lm_head(x)
 
 
+class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_block = TransformerBlock(config)
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    x = self.transformer_block(x, (cos, sin), mask, input_pos)
+    x = self.final_norm(x)
+    res = self.lm_head(x)
+    return res
+
+
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
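The new toy model ties its output projection to the token embedding via the assignment self.lm_head.weight.data = self.tok_embedding.weight.data. A minimal standalone sketch of that weight-sharing pattern (plain PyTorch, toy dimensions assumed):

import torch
from torch import nn

vocab_size, embedding_dim = 16, 8
tok_embedding = nn.Embedding(vocab_size, embedding_dim)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)

# Rebind the projection weights onto the embedding table. nn.Embedding and
# nn.Linear both store a (vocab_size, embedding_dim) weight, so the shapes
# line up, and after the assignment both modules view the same storage --
# converters and quantizers see one shared tensor rather than two copies.
lm_head.weight.data = tok_embedding.weight.data
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()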
@@ -70,7 +70,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -130,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
     )
 
     copied_model = copy.deepcopy(pytorch_model)
+    copied_edge = copy.deepcopy(edge_model)
 
     self.assertTrue(
         model_coverage.compare_tflite_torch(
@@ -141,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-    # TODO(b/362840003): figure why this decode output has big numerical diff.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              copied_model,
-              (decode_token, decode_input_pos),
-              signature_name="decode",
-              num_valid_inputs=1,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            copied_edge,
+            copied_model,
+            (decode_token, decode_input_pos),
+            signature_name="decode",
+            num_valid_inputs=1,
+        )
+    )
 
 
 if __name__ == "__main__":
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
     model.eval()
 
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    prefill_tokens[0, :4] = idx
+    prefill_input_pos = torch.arange(0, 10)
 
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model = ai_edge_torch.signature(
+        "prefill", model, (prefill_tokens, prefill_input_pos)
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    # TODO(b/362840003): debug numerical diff.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            model,
+            (prefill_tokens, prefill_input_pos),
+            signature_name="prefill",
+            num_valid_inputs=1,
+            atol=1e-2,
+            rtol=1e-5,
+        )
+    )
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
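The prefill test now converts through the named-signature API rather than a bare ai_edge_torch.convert call, which is what lets compare_tflite_torch target the entry point by signature_name. A hedged sketch of the same pattern for a model with both a prefill and a decode entry point (TinyModel and the sample shapes here are made-up stand-ins, not the library's test model):

import torch
import ai_edge_torch

class TinyModel(torch.nn.Module):
  """Stand-in model with the (tokens, input_pos) calling convention."""

  def __init__(self):
    super().__init__()
    self.embedding = torch.nn.Embedding(32, 8)

  def forward(self, tokens, input_pos):
    return self.embedding(tokens) + input_pos.to(torch.float32)[None, :, None]

model = TinyModel().eval()
prefill_tokens = torch.zeros((1, 10), dtype=torch.long)
prefill_input_pos = torch.arange(0, 10)
decode_token = torch.zeros((1, 1), dtype=torch.long)
decode_input_pos = torch.tensor([5])

# Each signature() call registers one named entry point; convert() emits a
# single .tflite carrying both, addressable via signature_name at runtime.
edge_model = (
    ai_edge_torch.signature("prefill", model, (prefill_tokens, prefill_input_pos))
    .signature("decode", model, (decode_token, decode_input_pos))
    .convert()
)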
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
 import torch
 
 from absl.testing import absltest as googletest
+from absl.testing import parameterized
 
 
-class TestVerifyRecipes(googletest.TestCase):
+class TestVerifyRecipes(parameterized.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (Dtype.FP32, Dtype.FP32),
      (Dtype.INT8, Dtype.INT8),
      (Dtype.INT8, Dtype.FP16),
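This hunk swaps the third-party parameterized package for absl's. Two things change together, as the rest of the diff shows: the decorator is spelled @parameterized.parameters instead of @parameterized.expand, and the test classes must derive from parameterized.TestCase for the decorator to take effect. A minimal self-contained example of the absl style:

from absl.testing import absltest as googletest
from absl.testing import parameterized

class ExampleTest(parameterized.TestCase):

  @parameterized.parameters(
      (1, 2),
      (3, 4),
  )
  def test_ordering(self, lo, hi):
    self.assertLess(lo, hi)

if __name__ == "__main__":
  googletest.main()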
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (
          Dtype.FP32,
          Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
       ).verify()
 
 
-class TestQuantizeConvert(googletest.TestCase):
+class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
         )
     )
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (quant_recipes.full_fp16_recipe()),
      (quant_recipes.full_int8_dynamic_recipe()),
      (quant_recipes.full_int8_weight_only_recipe()),
      (_attention_int8_dynamic_recipe()),
      (_feedforward_int8_dynamic_recipe()),
  ])
-  @googletest.skipIf(
-      not config.Config.use_torch_xla,
-      reason="Not working with odml_torch at the moment.",
-  )
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
         "Quantized model isn't smaller than F32 model.",
     )
 
+  def test_quantize_convert_toy_weight_sharing(self):
+    config = toy_model.get_model_config()
+    pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
+    idx = torch.unsqueeze(torch.arange(0, 100), 0)
+    input_pos = torch.arange(0, 100)
+
+    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertLess(
+        len(quantized_model._tflite_model),
+        len(float_model._tflite_model),
+        "Quantized model isn't smaller than F32 model.",
+    )
+
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
@@ -208,7 +208,7 @@ class ModelLoader:
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
      return load_pytorch_statedict
 
     raise ValueError("File format not supported.")
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.odml_torch import export
 from ai_edge_torch.odml_torch import export_utils
 from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
   converter._experimental_enable_composite_direct_lowering = True
   converter.model_origin_framework = "PYTORCH"
 
+  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+        quant_config.generative_recipe
+    )
+
   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
 
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    tflite_model = translate_recipe.quantize_model(
+        tflite_model, translated_recipe
+    )
+
   return tflite_model
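Taken together, the two added blocks bracket the converter call: when the quant config selects the AI_EDGE_QUANTIZER mode, the generative recipe is translated before conversion and the quantizer pass rewrites the float flatbuffer afterwards. A condensed sketch of that control flow (the wrapper function is hypothetical; the gating condition mirrors the diff):

def _convert_with_optional_quantizer(converter, quant_config):
  use_quantizer = (
      quant_config is not None
      and quant_config._quantizer_mode
      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
  )
  # Translate the recipe up front so a malformed recipe fails early.
  translated_recipe = (
      translate_recipe.translate_to_ai_edge_recipe(quant_config.generative_recipe)
      if use_quantizer
      else None
  )
  tflite_model = converter.convert()  # produces the float model
  if use_quantizer:
    # Post-conversion pass: quantize the serialized flatbuffer.
    tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)
  return tflite_model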
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
-from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 from torch_xla import stablehlo
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
-_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
 _OpName = quantizer.qtyping.TFLOperationName
 _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
     return quantizer.qtyping.TensorDataType.INT
 
 
-def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+def _get_compute_precision_from_mode(
+    mode: quant_attrs.Mode,
+) -> _ComputePrecision:
   if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-    return _OpExecutionMode.DRQ
+    return _ComputePrecision.INTEGER
   elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-    return _OpExecutionMode.WEIGHT_ONLY
+    return _ComputePrecision.FLOAT
   raise ValueError('Unimplemented execution mode')
 
 
-def _get_channelwise_from_granularity(
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return False
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return True
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_granularity(
     granularity: quant_attrs.Granularity,
 ) -> bool:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
-    return True
-  elif granularity == quant_attrs.Granularity.NONE:
-    return False
+    return _QuantGranularity.CHANNELWISE
+  if granularity == quant_attrs.Granularity.NONE:
+    return _QuantGranularity.TENSORWISE
   raise ValueError('Unimplemented granularity')
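The single OpExecutionMode knob is thus split into two orthogonal ones: the compute precision and whether an explicit dequantize op is materialized. Under the helpers above, DYNAMIC_RANGE maps to integer compute without an explicit dequantize (the old DRQ), and WEIGHT_ONLY maps to float compute with one (the old WEIGHT_ONLY). A quick check, assuming the helpers above are in scope inside this module:

from ai_edge_torch.generative.quantize import quant_attrs

for mode in (quant_attrs.Mode.DYNAMIC_RANGE, quant_attrs.Mode.WEIGHT_ONLY):
  print(
      mode,
      _get_compute_precision_from_mode(mode),  # INTEGER / FLOAT
      _get_explicit_dequant_from_mode(mode),   # False / True
  )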
@@ -88,12 +99,13 @@ def _set_quant_config(
           weight_tensor_config=_TensorQuantConfig(
               num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
               symmetric=True,
-              channel_wise=_get_channelwise_from_granularity(
-                  layer_recipe.granularity
-              ),
+              granularity=_get_granularity(layer_recipe.granularity),
               dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
           ),
-          execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+          compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+          explicit_dequantize=_get_explicit_dequant_from_mode(
+              layer_recipe.mode
+          ),
       ),
       algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
   )
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
   """Lower the exported program to MLIR."""
-  if torch.__version__ >= "2.2":
-    # torch version 2.1 didn't expose this yet
-    exported_program = exported_program.run_decompositions()
-    exported_program = exported_program.run_decompositions(
-        lowerings.decompositions()
-    )
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
 
   with export_utils.create_ir_context() as context, ir.Location.unknown():
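Since the wheel already requires torch>=2.4.0 (see the METADATA hunk below), the 2.1/2.2 version guard was dead code, and a single run_decompositions call with the custom table suffices. A minimal standalone use of the same API (toy module and empty decomposition table are assumptions for illustration):

import torch

ep = torch.export.export(torch.nn.Linear(4, 4), (torch.randn(1, 4),))
# One pass with an explicit decomposition table, as in the lowering above;
# an empty dict applies no extra decompositions beyond the export itself.
ep = ep.run_decompositions({})
print(ep.graph_module.code)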
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
 
 def _lower_to_ir_text(
     jaxfn, args, kwargs, ir_input_names: list[str] = None
-) -> str:
+) -> tuple[str, list[ir.Value]]:
   args = utils.tree_map_list_to_tuple(args)
   kwargs = utils.tree_map_list_to_tuple(kwargs)
 
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
       x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
   ]
 
-  def new_lowering(*args, **jax_lower_static_kwargs):
+  def lower_wrapper(*args):
+    nonlocal jax_lower_static_kwargs
+
     jaxfn_args = []
     jaxfn_kwargs = jax_lower_static_kwargs.copy()
     for name, arg in zip(jax_lower_argnames, args):
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
 
     return jaxfn(*jaxfn_args, **jaxfn_kwargs)
 
-  return (
-      jax.jit(new_lowering, static_argnames=static_argnames)
-      .lower(*jax_lower_args, **jax_lower_static_kwargs)
-      .as_text()
-  ), ir_inputs
+  return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
 
 
 def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
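The simplified return path leans on jax.jit(...).lower(...).as_text(), which specializes the function to the given input shapes and dtypes and serializes the lowered StableHLO module. A standalone illustration of that JAX API (the toy function is made up):

import jax
import jax.numpy as jnp

def scale_and_sin(x):
  return jnp.sin(x) * 2.0

# lower() traces for the concrete shape/dtype of the example input;
# as_text() returns the lowered StableHLO module as a string, which is
# what _lower_to_ir_text hands back alongside the collected ir_inputs.
print(jax.jit(scale_and_sin).lower(jnp.ones((2, 2))).as_text())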
@@ -52,6 +52,7 @@ class LoweringRegistry:
 
 
 global_registry = LoweringRegistry()
+global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
 global_registry.decompositions.update(
     torch._decomp.get_decompositions([
         torch.ops.aten.upsample_nearest2d,
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
     ])
 )
 
+torch._decomp.remove_decompositions(
+    global_registry.decompositions,
+    [
+        torch.ops.aten.roll,
+    ],
+)
+
 
 def lookup(op):
   return global_registry.lookup(op)
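Two registry changes land here: the table is now seeded with PyTorch's full Core ATen decomposition set, and aten.roll is then carved back out so the backend can lower it directly rather than through its decomposition. The same torch._decomp calls in isolation, as a sketch:

import torch

# Seed a table with the Core ATen decompositions...
table = dict(torch._decomp.core_aten_decompositions())
# ...then drop the ops the backend prefers to lower natively.
torch._decomp.remove_decompositions(table, [torch.ops.aten.roll])
print(len(table), "decompositions remain registered")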
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240904"
+__version__ = "0.3.0.dev20240906"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240904
+Version: 0.3.0.dev20240906
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
-Requires-Dist: ai-edge-quantizer-nightly==0.0.1.dev20240718
+Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
 then be run with TensorFlow Lite and MediaPipe. This enables applications for
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=YC_qiN7DiHKGt_u71WY_mNK2BaoZr6mvPw7s7rWHA84,706
+ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -77,7 +77,7 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
 ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=zgQN0I9z8Xm0HcAGJIrnGzZWXEgNN041n4C5MXMNMqA,5286
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
-ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
+ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
+ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=GKfW1X-QSFffQdVlBuD-bNpP265xcdUlfBY3-9I4f_o,7447
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-SRm9YNsIGsaVd5Cyp2PP-tdLBJH8EDoMFAa2y89a1w,9043
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
+ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
+ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
 ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -144,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=hXvhKtbH7lGytm6QZOKpTmaLJN3kfENBcSIKQ39ReXA,5478
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
@@ -152,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/registry.py,sha256=dcnxq8vV9rxSQqXkjSg9it7l6oP_sdfH8kIZdQNkQ_4,2653
+ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/METADATA,sha256=hyD9hQtmGBnd2MHwoSkhRnlupk00oncKRsN5PQe8iQk,1878
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================