ai-edge-torch-nightly 0.3.0.dev20240904__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl

@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
     return self.lm_head(x)
 
 
+class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+  def __init__(self, config: cfg.ModelConfig) -> None:
+    super().__init__()
+    self.lm_head = nn.Linear(
+        config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+    )
+    self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+    self.lm_head = nn.Linear(
+        config.embedding_dim,
+        config.vocab_size,
+        bias=config.lm_head_use_bias,
+    )
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+    self.transformer_block = TransformerBlock(config)
+    self.final_norm = builder.build_norm(
+        config.embedding_dim,
+        config.final_norm_config,
+    )
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.max_seq_len,
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
+        base=10_000,
+        condense_ratio=1,
+        dtype=torch.float32,
+        device=torch.device('cpu'),
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+    )
+    self.config = config
+
+  @torch.inference_mode
+  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+    x = self.tok_embedding(idx)
+    cos, sin = self.rope_cache
+
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+    mask = self.mask_cache.index_select(2, input_pos)
+    mask = mask[:, :, :, : self.config.max_seq_len]
+
+    x = self.transformer_block(x, (cos, sin), mask, input_pos)
+    x = self.final_norm(x)
+    res = self.lm_head(x)
+    return res
+
+
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
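
Context for the hunk above: the new toy model ties its LM head to the token embedding, so the exported graph contains two ops reading a single weight tensor; this is the case the weight-sharing quantization test further down exercises. A minimal standalone sketch of that weight-tying pattern (dimensions are illustrative, not taken from the model config):

import torch
import torch.nn as nn

vocab_size, embedding_dim = 100, 16
tok_embedding = nn.Embedding(vocab_size, embedding_dim)
lm_head = nn.Linear(embedding_dim, vocab_size, bias=False)
# Reuse the embedding matrix as the output projection, as the model above does.
lm_head.weight.data = tok_embedding.weight.data

x = tok_embedding(torch.tensor([[1, 2, 3]]))  # shape (1, 3, embedding_dim)
logits = lm_head(x)                           # shape (1, 3, vocab_size)
# Both modules now read the same underlying storage.
assert lm_head.weight.data_ptr() == tok_embedding.weight.data_ptr()
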
@@ -70,7 +70,6 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -130,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
     )
 
     copied_model = copy.deepcopy(pytorch_model)
+    copied_edge = copy.deepcopy(edge_model)
 
     self.assertTrue(
         model_coverage.compare_tflite_torch(
@@ -141,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-    # TODO(b/362840003): figure why this decode output has big numerical diff.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              copied_model,
-              (decode_token, decode_input_pos),
-              signature_name="decode",
-              num_valid_inputs=1,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            copied_edge,
+            copied_model,
+            (decode_token, decode_input_pos),
+            signature_name="decode",
+            num_valid_inputs=1,
+        )
+    )
 
 
 if __name__ == "__main__":
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
     model.eval()
 
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
-    tokens[0, :4] = idx
-    input_pos = torch.arange(0, 10)
+    prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
+    prefill_tokens[0, :4] = idx
+    prefill_input_pos = torch.arange(0, 10)
 
-    edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
+    edge_model = ai_edge_torch.signature(
+        "prefill", model, (prefill_tokens, prefill_input_pos)
+    ).convert()
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
 
-    # TODO(b/362840003): debug numerical diff.
-    skip_output_check = True
-    if not skip_output_check:
-      self.assertTrue(
-          model_coverage.compare_tflite_torch(
-              edge_model,
-              model,
-              (tokens, input_pos),
-              num_valid_inputs=1,
-              atol=1e-2,
-              rtol=1e-5,
-          )
-      )
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(
+            edge_model,
+            model,
+            (prefill_tokens, prefill_input_pos),
+            signature_name="prefill",
+            num_valid_inputs=1,
+            atol=1e-2,
+            rtol=1e-5,
+        )
+    )
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
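
For reference, ai_edge_torch.signature(name, module, sample_args) is the public entry point the test now uses to export under a named signature; signatures can be chained before convert(). A hedged usage sketch, assuming model and the sample tensors are defined as in the test above:

import ai_edge_torch

edge_model = (
    ai_edge_torch.signature("prefill", model, (prefill_tokens, prefill_input_pos))
    .signature("decode", model, (decode_token, decode_input_pos))
    .convert()
)
# Invoke a specific signature on the converted model by name:
logits = edge_model(prefill_tokens, prefill_input_pos, signature_name="prefill")
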
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
 from ai_edge_torch.generative.quantize.quant_attrs import Mode
 from ai_edge_torch.quantize import quant_config
 from ai_edge_torch.testing import model_coverage
-from parameterized import parameterized
 import torch
 
 from absl.testing import absltest as googletest
+from absl.testing import parameterized
 
 
-class TestVerifyRecipes(googletest.TestCase):
+class TestVerifyRecipes(parameterized.TestCase):
   """Unit tests that check for model quantization recipes."""
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (Dtype.FP32, Dtype.FP32),
      (Dtype.INT8, Dtype.INT8),
      (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
     with self.assertRaises(ValueError):
       quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (
          Dtype.FP32,
          Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
      ).verify()
 
 
-class TestQuantizeConvert(googletest.TestCase):
+class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
         )
     )
 
-  @parameterized.expand([
+  @parameterized.parameters([
      (quant_recipes.full_fp16_recipe()),
      (quant_recipes.full_int8_dynamic_recipe()),
      (quant_recipes.full_int8_weight_only_recipe()),
      (_attention_int8_dynamic_recipe()),
      (_feedforward_int8_dynamic_recipe()),
  ])
-  @googletest.skipIf(
-      not config.Config.use_torch_xla,
-      reason="Not working with odml_torch at the moment.",
-  )
   def test_quantize_convert_toy_sizes(self, quant_config):
     config = toy_model.get_model_config()
     pytorch_model = toy_model.ToySingleLayerModel(config)
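
The test files also migrate from the third-party parameterized package to absl's parameterized module, whose decorator is @parameterized.parameters and which requires the test class to subclass parameterized.TestCase. A self-contained sketch of the new pattern:

from absl.testing import absltest as googletest
from absl.testing import parameterized


class ExampleTest(parameterized.TestCase):

  @parameterized.parameters(
      (2, 4),
      (3, 9),
  )
  def test_square(self, x, expected):
    self.assertEqual(x * x, expected)


if __name__ == "__main__":
  googletest.main()
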
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
         "Quantized model isn't smaller than F32 model.",
     )
 
+  def test_quantize_convert_toy_weight_sharing(self):
+    config = toy_model.get_model_config()
+    pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
+    idx = torch.unsqueeze(torch.arange(0, 100), 0)
+    input_pos = torch.arange(0, 100)
+
+    quant_config = quant_recipes.full_int8_dynamic_recipe()
+    quantized_model = ai_edge_torch.convert(
+        pytorch_model, (idx, input_pos), quant_config=quant_config
+    )
+    float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+    self.assertLess(
+        len(quantized_model._tflite_model),
+        len(float_model._tflite_model),
+        "Quantized model isn't smaller than F32 model.",
+    )
+
   def test_quantize_convert_compare_toy(self):
     self.skipTest("b/338288901")
     config = toy_model_with_kv_cache.get_model_config()
@@ -208,7 +208,7 @@ class ModelLoader:
     if self._file_name.endswith(".safetensors"):
       return load_safetensors
 
-    if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+    if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
       return load_pytorch_statedict
 
     raise ValueError("File format not supported.")
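
Dropping the dot from the second suffix check relaxes the match from ".pt" files to any name ending in "pt", which presumably lets ".ckpt" checkpoints route to the PyTorch loader as well:

# Quick check of the relaxed suffix test:
for name in ("weights.bin", "weights.pt", "weights.ckpt", "weights.safetensors"):
  print(name, name.endswith(".bin") or name.endswith("pt"))
# weights.bin True, weights.pt True, weights.ckpt True, weights.safetensors False
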
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.odml_torch import export
 from ai_edge_torch.odml_torch import export_utils
 from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
   converter._experimental_enable_composite_direct_lowering = True
   converter.model_origin_framework = "PYTORCH"
 
+  conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+        quant_config.generative_recipe
+    )
+
   conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
   tflite_model = converter.convert()
 
+  if (
+      quant_config is not None
+      and quant_config._quantizer_mode
+      == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+  ):
+    tflite_model = translate_recipe.quantize_model(
+        tflite_model, translated_recipe
+    )
+
   return tflite_model
 
 
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
-from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe  # NOQA
 from ai_edge_torch.lowertools import common_utils
+from ai_edge_torch.lowertools import translate_recipe
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 from torch_xla import stablehlo
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
 from ai_edge_torch.generative.quantize import quant_attrs
 from ai_edge_torch.generative.quantize import quant_recipe
 
-_OpExecutionMode = quantizer.qtyping.OpExecutionMode
+_ComputePrecision = quantizer.qtyping.ComputePrecision
+_QuantGranularity = quantizer.qtyping.QuantGranularity
 _OpName = quantizer.qtyping.TFLOperationName
 _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
     return quantizer.qtyping.TensorDataType.INT
 
 
-def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+def _get_compute_precision_from_mode(
+    mode: quant_attrs.Mode,
+) -> _ComputePrecision:
   if mode == quant_attrs.Mode.DYNAMIC_RANGE:
-    return _OpExecutionMode.DRQ
+    return _ComputePrecision.INTEGER
   elif mode == quant_attrs.Mode.WEIGHT_ONLY:
-    return _OpExecutionMode.WEIGHT_ONLY
+    return _ComputePrecision.FLOAT
   raise ValueError('Unimplemented execution mode')
 
 
-def _get_channelwise_from_granularity(
+def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+    return False
+  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+    return True
+  raise ValueError('Unimplemented execution mode')
+
+
+def _get_granularity(
     granularity: quant_attrs.Granularity,
 ) -> bool:
   if granularity == quant_attrs.Granularity.CHANNELWISE:
-    return True
-  elif granularity == quant_attrs.Granularity.NONE:
-    return False
+    return _QuantGranularity.CHANNELWISE
+  if granularity == quant_attrs.Granularity.NONE:
+    return _QuantGranularity.TENSORWISE
   raise ValueError('Unimplemented granularity')
 
 
@@ -88,12 +99,13 @@ def _set_quant_config(
          weight_tensor_config=_TensorQuantConfig(
              num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
              symmetric=True,
-             channel_wise=_get_channelwise_from_granularity(
-                 layer_recipe.granularity
-             ),
+             granularity=_get_granularity(layer_recipe.granularity),
              dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
          ),
-         execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+         compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+         explicit_dequantize=_get_explicit_dequant_from_mode(
+             layer_recipe.mode
+         ),
      ),
      algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
  )
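
The hunks above track an AI Edge Quantizer API change: the single OpExecutionMode enum is replaced by two orthogonal settings, compute precision and explicit dequantization. Dynamic-range quantization maps to integer compute without an explicit dequantize; weight-only maps to float compute with one. A standalone sketch of that mapping (enum and return values are illustrative, not the quantizer's own types):

import enum


class Mode(enum.Enum):
  DYNAMIC_RANGE = 1
  WEIGHT_ONLY = 2


def compute_precision(mode: Mode) -> str:
  # DRQ computes in integer arithmetic; weight-only keeps float compute.
  return {Mode.DYNAMIC_RANGE: "INTEGER", Mode.WEIGHT_ONLY: "FLOAT"}[mode]


def explicit_dequantize(mode: Mode) -> bool:
  # Weight-only inserts an explicit dequantize before the float op.
  return mode == Mode.WEIGHT_ONLY
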
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
     exported_program: torch.export.ExportedProgram,
 ) -> MlirLowered:
   """Lower the exported program to MLIR."""
-  if torch.__version__ >= "2.2":
-    # torch version 2.1 didn't expose this yet
-    exported_program = exported_program.run_decompositions()
-    exported_program = exported_program.run_decompositions(
-        lowerings.decompositions()
-    )
+  exported_program = exported_program.run_decompositions(
+      lowerings.decompositions()
+  )
 
   with export_utils.create_ir_context() as context, ir.Location.unknown():
 
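
The torch-version gate is gone because the package now requires torch>=2.4 (see the METADATA hunk below), where run_decompositions accepts a decomposition table directly. A minimal standalone sketch of the same call on a toy module:

import torch


class AddOne(torch.nn.Module):

  def forward(self, x):
    return x + 1


ep = torch.export.export(AddOne(), (torch.randn(4),))
# Decompose with a caller-supplied table, as exported_program_to_mlir now does
# unconditionally (here the stock core-ATen table stands in for
# lowerings.decompositions()).
ep = ep.run_decompositions(torch._decomp.core_aten_decompositions())
print(ep.graph_module.code)
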
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
 
 def _lower_to_ir_text(
     jaxfn, args, kwargs, ir_input_names: list[str] = None
-) -> str:
+) -> tuple[str, list[ir.Value]]:
   args = utils.tree_map_list_to_tuple(args)
   kwargs = utils.tree_map_list_to_tuple(kwargs)
 
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
       x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
   ]
 
-  def new_lowering(*args, **jax_lower_static_kwargs):
+  def lower_wrapper(*args):
+    nonlocal jax_lower_static_kwargs
+
     jaxfn_args = []
     jaxfn_kwargs = jax_lower_static_kwargs.copy()
     for name, arg in zip(jax_lower_argnames, args):
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
 
     return jaxfn(*jaxfn_args, **jaxfn_kwargs)
 
-  return (
-      jax.jit(new_lowering, static_argnames=static_argnames)
-      .lower(*jax_lower_args, **jax_lower_static_kwargs)
-      .as_text()
-  ), ir_inputs
+  return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
 
 
 def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
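
The rewritten helper closes over the static kwargs with nonlocal instead of threading them through jax.jit(static_argnames=...). The lowering chain itself is stock JAX: jit(fn).lower(args).as_text() produces the MLIR text without executing the function. A minimal sketch:

import jax
import jax.numpy as jnp


def f(x):
  return jnp.sin(x) * 2.0


# Lower without executing, then dump the IR as text.
mlir_text = jax.jit(f).lower(jnp.ones((4,))).as_text()
print(mlir_text.splitlines()[0])
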
@@ -52,6 +52,7 @@ class LoweringRegistry:
 
 
 global_registry = LoweringRegistry()
+global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
 global_registry.decompositions.update(
     torch._decomp.get_decompositions([
         torch.ops.aten.upsample_nearest2d,
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
     ])
 )
 
+torch._decomp.remove_decompositions(
+    global_registry.decompositions,
+    [
+        torch.ops.aten.roll,
+    ],
+)
+
 
 def lookup(op):
   return global_registry.lookup(op)
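
core_aten_decompositions() seeds the registry with PyTorch's standard core-ATen decomposition table, and remove_decompositions() then carves out ops (here aten.roll) that odml_torch presumably lowers directly. A standalone sketch of the same two calls:

import torch

table = torch._decomp.core_aten_decompositions()
torch._decomp.remove_decompositions(table, [torch.ops.aten.roll])
# aten.roll overloads are no longer decomposed by this table.
print(torch.ops.aten.roll.default in table)  # expected: False
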
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240904"
+__version__ = "0.3.0.dev20240906"
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240904
+Version: 0.3.0.dev20240906
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
 Requires-Dist: torch>=2.4.0
 Requires-Dist: torch-xla>=2.4.0
 Requires-Dist: tf-nightly>=2.18.0.dev20240722
-Requires-Dist: ai-edge-quantizer-nightly==0.0.1.dev20240718
+Requires-Dist: ai-edge-quantizer-nightly
 
 Library that supports converting PyTorch models into a .tflite format, which can
 then be run with TensorFlow Lite and MediaPipe. This enables applications for
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=YC_qiN7DiHKGt_u71WY_mNK2BaoZr6mvPw7s7rWHA84,706
+ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -77,7 +77,7 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
 ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
+ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
 ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
 ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
 ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=zgQN0I9z8Xm0HcAGJIrnGzZWXEgNN041n4C5MXMNMqA,5286
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
-ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
+ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
+ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
 ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
 ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
-ai_edge_torch/lowertools/odml_torch_utils.py,sha256=GKfW1X-QSFffQdVlBuD-bNpP265xcdUlfBY3-9I4f_o,7447
+ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-SRm9YNsIGsaVd5Cyp2PP-tdLBJH8EDoMFAa2y89a1w,9043
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
+ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
+ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
 ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
 ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -144,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
 ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
 ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
-ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=hXvhKtbH7lGytm6QZOKpTmaLJN3kfENBcSIKQ39ReXA,5478
+ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
 ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
@@ -152,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
 ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
-ai_edge_torch/odml_torch/lowerings/registry.py,sha256=dcnxq8vV9rxSQqXkjSg9it7l6oP_sdfH8kIZdQNkQ_4,2653
+ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
 ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
 ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/METADATA,sha256=hyD9hQtmGBnd2MHwoSkhRnlupk00oncKRsN5PQe8iQk,1878
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240904.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================