ai-edge-torch-nightly 0.3.0.dev20240904__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
- ai_edge_torch/generative/test/test_model_conversion.py +10 -13
- ai_edge_torch/generative/test/test_model_conversion_large.py +17 -17
- ai_edge_torch/generative/test/test_quantize.py +23 -10
- ai_edge_torch/generative/utilities/loader.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
- ai_edge_torch/odml_torch/export.py +3 -6
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +5 -7
- ai_edge_torch/odml_torch/lowerings/registry.py +8 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/RECORD +17 -18
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/top_level.txt +0 -0
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
|
|
71
71
|
return self.lm_head(x)
|
72
72
|
|
73
73
|
|
74
|
+
class ToySingleLayerModelWeightSharing(torch.nn.Module):
|
75
|
+
|
76
|
+
def __init__(self, config: cfg.ModelConfig) -> None:
|
77
|
+
super().__init__()
|
78
|
+
self.lm_head = nn.Linear(
|
79
|
+
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
80
|
+
)
|
81
|
+
self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
|
82
|
+
self.lm_head = nn.Linear(
|
83
|
+
config.embedding_dim,
|
84
|
+
config.vocab_size,
|
85
|
+
bias=config.lm_head_use_bias,
|
86
|
+
)
|
87
|
+
self.lm_head.weight.data = self.tok_embedding.weight.data
|
88
|
+
self.transformer_block = TransformerBlock(config)
|
89
|
+
self.final_norm = builder.build_norm(
|
90
|
+
config.embedding_dim,
|
91
|
+
config.final_norm_config,
|
92
|
+
)
|
93
|
+
self.rope_cache = attn_utils.build_rope_cache(
|
94
|
+
size=config.max_seq_len,
|
95
|
+
dim=int(
|
96
|
+
config.attn_config.rotary_percentage * config.attn_config.head_dim
|
97
|
+
),
|
98
|
+
base=10_000,
|
99
|
+
condense_ratio=1,
|
100
|
+
dtype=torch.float32,
|
101
|
+
device=torch.device('cpu'),
|
102
|
+
)
|
103
|
+
self.mask_cache = attn_utils.build_causal_mask_cache(
|
104
|
+
size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
|
105
|
+
)
|
106
|
+
self.config = config
|
107
|
+
|
108
|
+
@torch.inference_mode
|
109
|
+
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
|
110
|
+
x = self.tok_embedding(idx)
|
111
|
+
cos, sin = self.rope_cache
|
112
|
+
|
113
|
+
cos = cos.index_select(0, input_pos)
|
114
|
+
sin = sin.index_select(0, input_pos)
|
115
|
+
mask = self.mask_cache.index_select(2, input_pos)
|
116
|
+
mask = mask[:, :, :, : self.config.max_seq_len]
|
117
|
+
|
118
|
+
x = self.transformer_block(x, (cos, sin), mask, input_pos)
|
119
|
+
x = self.final_norm(x)
|
120
|
+
res = self.lm_head(x)
|
121
|
+
return res
|
122
|
+
|
123
|
+
|
74
124
|
def get_model_config() -> cfg.ModelConfig:
|
75
125
|
attn_config = cfg.AttentionConfig(
|
76
126
|
num_heads=32,
|
@@ -70,7 +70,6 @@ class TestModelConversion(googletest.TestCase):
|
|
70
70
|
)
|
71
71
|
)
|
72
72
|
|
73
|
-
|
74
73
|
@googletest.skipIf(
|
75
74
|
ai_edge_config.Config.use_torch_xla,
|
76
75
|
reason="tests with custom ops are not supported on oss",
|
@@ -130,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
|
|
130
129
|
)
|
131
130
|
|
132
131
|
copied_model = copy.deepcopy(pytorch_model)
|
132
|
+
copied_edge = copy.deepcopy(edge_model)
|
133
133
|
|
134
134
|
self.assertTrue(
|
135
135
|
model_coverage.compare_tflite_torch(
|
@@ -141,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
|
|
141
141
|
)
|
142
142
|
)
|
143
143
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
num_valid_inputs=1,
|
154
|
-
)
|
155
|
-
)
|
144
|
+
self.assertTrue(
|
145
|
+
model_coverage.compare_tflite_torch(
|
146
|
+
copied_edge,
|
147
|
+
copied_model,
|
148
|
+
(decode_token, decode_input_pos),
|
149
|
+
signature_name="decode",
|
150
|
+
num_valid_inputs=1,
|
151
|
+
)
|
152
|
+
)
|
156
153
|
|
157
154
|
|
158
155
|
if __name__ == "__main__":
|
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
|
|
82
82
|
model.eval()
|
83
83
|
|
84
84
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
85
|
-
|
86
|
-
|
87
|
-
|
85
|
+
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
86
|
+
prefill_tokens[0, :4] = idx
|
87
|
+
prefill_input_pos = torch.arange(0, 10)
|
88
88
|
|
89
|
-
edge_model = ai_edge_torch.
|
89
|
+
edge_model = ai_edge_torch.signature(
|
90
|
+
"prefill", model, (prefill_tokens, prefill_input_pos)
|
91
|
+
).convert()
|
90
92
|
edge_model.set_interpreter_builder(
|
91
93
|
self._interpreter_builder(edge_model.tflite_model())
|
92
94
|
)
|
93
95
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
)
|
106
|
-
)
|
96
|
+
self.assertTrue(
|
97
|
+
model_coverage.compare_tflite_torch(
|
98
|
+
edge_model,
|
99
|
+
model,
|
100
|
+
(prefill_tokens, prefill_input_pos),
|
101
|
+
signature_name="prefill",
|
102
|
+
num_valid_inputs=1,
|
103
|
+
atol=1e-2,
|
104
|
+
rtol=1e-5,
|
105
|
+
)
|
106
|
+
)
|
107
107
|
|
108
108
|
@googletest.skipIf(
|
109
109
|
ai_edge_config.Config.use_torch_xla,
|
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
|
|
25
25
|
from ai_edge_torch.generative.quantize.quant_attrs import Mode
|
26
26
|
from ai_edge_torch.quantize import quant_config
|
27
27
|
from ai_edge_torch.testing import model_coverage
|
28
|
-
from parameterized import parameterized
|
29
28
|
import torch
|
30
29
|
|
31
30
|
from absl.testing import absltest as googletest
|
31
|
+
from absl.testing import parameterized
|
32
32
|
|
33
33
|
|
34
|
-
class TestVerifyRecipes(
|
34
|
+
class TestVerifyRecipes(parameterized.TestCase):
|
35
35
|
"""Unit tests that check for model quantization recipes."""
|
36
36
|
|
37
|
-
@parameterized.
|
37
|
+
@parameterized.parameters([
|
38
38
|
(Dtype.FP32, Dtype.FP32),
|
39
39
|
(Dtype.INT8, Dtype.INT8),
|
40
40
|
(Dtype.INT8, Dtype.FP16),
|
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
|
|
52
52
|
with self.assertRaises(ValueError):
|
53
53
|
quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
|
54
54
|
|
55
|
-
@parameterized.
|
55
|
+
@parameterized.parameters([
|
56
56
|
(
|
57
57
|
Dtype.FP32,
|
58
58
|
Dtype.INT8,
|
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
|
|
88
88
|
).verify()
|
89
89
|
|
90
90
|
|
91
|
-
class TestQuantizeConvert(
|
91
|
+
class TestQuantizeConvert(parameterized.TestCase):
|
92
92
|
"""Test conversion with quantization."""
|
93
93
|
|
94
94
|
def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
|
|
105
105
|
)
|
106
106
|
)
|
107
107
|
|
108
|
-
@parameterized.
|
108
|
+
@parameterized.parameters([
|
109
109
|
(quant_recipes.full_fp16_recipe()),
|
110
110
|
(quant_recipes.full_int8_dynamic_recipe()),
|
111
111
|
(quant_recipes.full_int8_weight_only_recipe()),
|
112
112
|
(_attention_int8_dynamic_recipe()),
|
113
113
|
(_feedforward_int8_dynamic_recipe()),
|
114
114
|
])
|
115
|
-
@googletest.skipIf(
|
116
|
-
not config.Config.use_torch_xla,
|
117
|
-
reason="Not working with odml_torch at the moment.",
|
118
|
-
)
|
119
115
|
def test_quantize_convert_toy_sizes(self, quant_config):
|
120
116
|
config = toy_model.get_model_config()
|
121
117
|
pytorch_model = toy_model.ToySingleLayerModel(config)
|
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
|
|
132
128
|
"Quantized model isn't smaller than F32 model.",
|
133
129
|
)
|
134
130
|
|
131
|
+
def test_quantize_convert_toy_weight_sharing(self):
|
132
|
+
config = toy_model.get_model_config()
|
133
|
+
pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
|
134
|
+
idx = torch.unsqueeze(torch.arange(0, 100), 0)
|
135
|
+
input_pos = torch.arange(0, 100)
|
136
|
+
|
137
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe()
|
138
|
+
quantized_model = ai_edge_torch.convert(
|
139
|
+
pytorch_model, (idx, input_pos), quant_config=quant_config
|
140
|
+
)
|
141
|
+
float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
142
|
+
self.assertLess(
|
143
|
+
len(quantized_model._tflite_model),
|
144
|
+
len(float_model._tflite_model),
|
145
|
+
"Quantized model isn't smaller than F32 model.",
|
146
|
+
)
|
147
|
+
|
135
148
|
def test_quantize_convert_compare_toy(self):
|
136
149
|
self.skipTest("b/338288901")
|
137
150
|
config = toy_model_with_kv_cache.get_model_config()
|
@@ -208,7 +208,7 @@ class ModelLoader:
|
|
208
208
|
if self._file_name.endswith(".safetensors"):
|
209
209
|
return load_safetensors
|
210
210
|
|
211
|
-
if self._file_name.endswith(".bin") or self._file_name.endswith("
|
211
|
+
if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
|
212
212
|
return load_pytorch_statedict
|
213
213
|
|
214
214
|
raise ValueError("File format not supported.")
|
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
|
|
21
21
|
from ai_edge_torch._convert import conversion_utils
|
22
22
|
from ai_edge_torch._convert import signature as signature_module
|
23
23
|
from ai_edge_torch.lowertools import common_utils
|
24
|
+
from ai_edge_torch.lowertools import translate_recipe
|
24
25
|
from ai_edge_torch.odml_torch import export
|
25
26
|
from ai_edge_torch.odml_torch import export_utils
|
26
27
|
from ai_edge_torch.quantize import quant_config as qcfg
|
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
|
|
186
187
|
converter._experimental_enable_composite_direct_lowering = True
|
187
188
|
converter.model_origin_framework = "PYTORCH"
|
188
189
|
|
190
|
+
conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
|
191
|
+
if (
|
192
|
+
quant_config is not None
|
193
|
+
and quant_config._quantizer_mode
|
194
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
195
|
+
):
|
196
|
+
translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
|
197
|
+
quant_config.generative_recipe
|
198
|
+
)
|
199
|
+
|
189
200
|
conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
|
190
201
|
|
191
202
|
tflite_model = converter.convert()
|
192
203
|
|
204
|
+
if (
|
205
|
+
quant_config is not None
|
206
|
+
and quant_config._quantizer_mode
|
207
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
208
|
+
):
|
209
|
+
tflite_model = translate_recipe.quantize_model(
|
210
|
+
tflite_model, translated_recipe
|
211
|
+
)
|
212
|
+
|
193
213
|
return tflite_model
|
194
214
|
|
195
215
|
|
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|
25
25
|
from ai_edge_torch import model
|
26
26
|
from ai_edge_torch._convert import conversion_utils
|
27
27
|
from ai_edge_torch._convert import signature as signature_module
|
28
|
-
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
29
28
|
from ai_edge_torch.lowertools import common_utils
|
29
|
+
from ai_edge_torch.lowertools import translate_recipe
|
30
30
|
from ai_edge_torch.quantize import quant_config as qcfg
|
31
31
|
import torch
|
32
32
|
from torch_xla import stablehlo
|
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
|
|
17
17
|
from ai_edge_torch.generative.quantize import quant_attrs
|
18
18
|
from ai_edge_torch.generative.quantize import quant_recipe
|
19
19
|
|
20
|
-
|
20
|
+
_ComputePrecision = quantizer.qtyping.ComputePrecision
|
21
|
+
_QuantGranularity = quantizer.qtyping.QuantGranularity
|
21
22
|
_OpName = quantizer.qtyping.TFLOperationName
|
22
23
|
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
|
23
24
|
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
|
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
|
|
50
51
|
return quantizer.qtyping.TensorDataType.INT
|
51
52
|
|
52
53
|
|
53
|
-
def
|
54
|
+
def _get_compute_precision_from_mode(
|
55
|
+
mode: quant_attrs.Mode,
|
56
|
+
) -> _ComputePrecision:
|
54
57
|
if mode == quant_attrs.Mode.DYNAMIC_RANGE:
|
55
|
-
return
|
58
|
+
return _ComputePrecision.INTEGER
|
56
59
|
elif mode == quant_attrs.Mode.WEIGHT_ONLY:
|
57
|
-
return
|
60
|
+
return _ComputePrecision.FLOAT
|
58
61
|
raise ValueError('Unimplemented execution mode')
|
59
62
|
|
60
63
|
|
61
|
-
def
|
64
|
+
def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
|
65
|
+
if mode == quant_attrs.Mode.DYNAMIC_RANGE:
|
66
|
+
return False
|
67
|
+
elif mode == quant_attrs.Mode.WEIGHT_ONLY:
|
68
|
+
return True
|
69
|
+
raise ValueError('Unimplemented execution mode')
|
70
|
+
|
71
|
+
|
72
|
+
def _get_granularity(
|
62
73
|
granularity: quant_attrs.Granularity,
|
63
74
|
) -> bool:
|
64
75
|
if granularity == quant_attrs.Granularity.CHANNELWISE:
|
65
|
-
return
|
66
|
-
|
67
|
-
return
|
76
|
+
return _QuantGranularity.CHANNELWISE
|
77
|
+
if granularity == quant_attrs.Granularity.NONE:
|
78
|
+
return _QuantGranularity.TENSORWISE
|
68
79
|
raise ValueError('Unimplemented granularity')
|
69
80
|
|
70
81
|
|
@@ -88,12 +99,13 @@ def _set_quant_config(
|
|
88
99
|
weight_tensor_config=_TensorQuantConfig(
|
89
100
|
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
|
90
101
|
symmetric=True,
|
91
|
-
|
92
|
-
layer_recipe.granularity
|
93
|
-
),
|
102
|
+
granularity=_get_granularity(layer_recipe.granularity),
|
94
103
|
dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
|
95
104
|
),
|
96
|
-
|
105
|
+
compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
|
106
|
+
explicit_dequantize=_get_explicit_dequant_from_mode(
|
107
|
+
layer_recipe.mode
|
108
|
+
),
|
97
109
|
),
|
98
110
|
algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
|
99
111
|
)
|
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
|
|
227
227
|
exported_program: torch.export.ExportedProgram,
|
228
228
|
) -> MlirLowered:
|
229
229
|
"""Lower the exported program to MLIR."""
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
exported_program = exported_program.run_decompositions(
|
234
|
-
lowerings.decompositions()
|
235
|
-
)
|
230
|
+
exported_program = exported_program.run_decompositions(
|
231
|
+
lowerings.decompositions()
|
232
|
+
)
|
236
233
|
|
237
234
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
238
235
|
|
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
|
|
35
35
|
|
36
36
|
def _lower_to_ir_text(
|
37
37
|
jaxfn, args, kwargs, ir_input_names: list[str] = None
|
38
|
-
) -> str:
|
38
|
+
) -> tuple[str, list[ir.Value]]:
|
39
39
|
args = utils.tree_map_list_to_tuple(args)
|
40
40
|
kwargs = utils.tree_map_list_to_tuple(kwargs)
|
41
41
|
|
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
|
|
74
74
|
x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
|
75
75
|
]
|
76
76
|
|
77
|
-
def
|
77
|
+
def lower_wrapper(*args):
|
78
|
+
nonlocal jax_lower_static_kwargs
|
79
|
+
|
78
80
|
jaxfn_args = []
|
79
81
|
jaxfn_kwargs = jax_lower_static_kwargs.copy()
|
80
82
|
for name, arg in zip(jax_lower_argnames, args):
|
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
|
|
85
87
|
|
86
88
|
return jaxfn(*jaxfn_args, **jaxfn_kwargs)
|
87
89
|
|
88
|
-
return (
|
89
|
-
jax.jit(new_lowering, static_argnames=static_argnames)
|
90
|
-
.lower(*jax_lower_args, **jax_lower_static_kwargs)
|
91
|
-
.as_text()
|
92
|
-
), ir_inputs
|
90
|
+
return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
|
93
91
|
|
94
92
|
|
95
93
|
def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
|
@@ -52,6 +52,7 @@ class LoweringRegistry:
|
|
52
52
|
|
53
53
|
|
54
54
|
global_registry = LoweringRegistry()
|
55
|
+
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
55
56
|
global_registry.decompositions.update(
|
56
57
|
torch._decomp.get_decompositions([
|
57
58
|
torch.ops.aten.upsample_nearest2d,
|
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
|
|
70
71
|
])
|
71
72
|
)
|
72
73
|
|
74
|
+
torch._decomp.remove_decompositions(
|
75
|
+
global_registry.decompositions,
|
76
|
+
[
|
77
|
+
torch.ops.aten.roll,
|
78
|
+
],
|
79
|
+
)
|
80
|
+
|
73
81
|
|
74
82
|
def lookup(op):
|
75
83
|
return global_registry.lookup(op)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240906
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
32
|
Requires-Dist: tf-nightly>=2.18.0.dev20240722
|
33
|
-
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: ai-edge-quantizer-nightly
|
34
34
|
|
35
35
|
Library that supports converting PyTorch models into a .tflite format, which can
|
36
36
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -77,7 +77,7 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
|
|
77
77
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
|
78
78
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
|
79
79
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
80
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
80
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
|
81
81
|
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
|
82
82
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
|
83
83
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
|
|
106
106
|
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
|
107
107
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
108
108
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
109
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
110
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
|
111
109
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
112
110
|
ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
|
113
111
|
ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
|
114
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
115
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
116
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
112
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
|
113
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
|
114
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
|
117
115
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
118
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
116
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
|
119
117
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
|
120
118
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
|
121
119
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
|
|
128
126
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
129
127
|
ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
|
130
128
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
131
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
129
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
|
132
130
|
ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
|
133
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256
|
131
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
|
132
|
+
ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
|
134
133
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
135
134
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
136
135
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
137
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
136
|
+
ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
|
138
137
|
ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
|
139
138
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
140
139
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -144,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
|
|
144
143
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
|
145
144
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
146
145
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
147
|
-
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
|
146
|
+
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
148
147
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
149
148
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
|
150
149
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
|
@@ -152,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
152
151
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
|
153
152
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
|
154
153
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
155
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
154
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
|
156
155
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
157
156
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
158
157
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
162
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
163
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
164
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
168
|
-
ai_edge_torch_nightly-0.3.0.
|
169
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
File without changes
|
File without changes
|