ai-edge-torch-nightly 0.3.0.dev20240904__py3-none-any.whl → 0.3.0.dev20240906__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
- ai_edge_torch/generative/test/test_model_conversion.py +10 -13
- ai_edge_torch/generative/test/test_model_conversion_large.py +17 -17
- ai_edge_torch/generative/test/test_quantize.py +23 -10
- ai_edge_torch/generative/utilities/loader.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
- ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
- ai_edge_torch/odml_torch/export.py +3 -6
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +5 -7
- ai_edge_torch/odml_torch/lowerings/registry.py +8 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/METADATA +2 -2
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/RECORD +17 -18
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240904.dist-info → ai_edge_torch_nightly-0.3.0.dev20240906.dist-info}/top_level.txt +0 -0
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
|
|
71
71
|
return self.lm_head(x)
|
72
72
|
|
73
73
|
|
74
|
+
class ToySingleLayerModelWeightSharing(torch.nn.Module):
|
75
|
+
|
76
|
+
def __init__(self, config: cfg.ModelConfig) -> None:
|
77
|
+
super().__init__()
|
78
|
+
self.lm_head = nn.Linear(
|
79
|
+
config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
|
80
|
+
)
|
81
|
+
self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
|
82
|
+
self.lm_head = nn.Linear(
|
83
|
+
config.embedding_dim,
|
84
|
+
config.vocab_size,
|
85
|
+
bias=config.lm_head_use_bias,
|
86
|
+
)
|
87
|
+
self.lm_head.weight.data = self.tok_embedding.weight.data
|
88
|
+
self.transformer_block = TransformerBlock(config)
|
89
|
+
self.final_norm = builder.build_norm(
|
90
|
+
config.embedding_dim,
|
91
|
+
config.final_norm_config,
|
92
|
+
)
|
93
|
+
self.rope_cache = attn_utils.build_rope_cache(
|
94
|
+
size=config.max_seq_len,
|
95
|
+
dim=int(
|
96
|
+
config.attn_config.rotary_percentage * config.attn_config.head_dim
|
97
|
+
),
|
98
|
+
base=10_000,
|
99
|
+
condense_ratio=1,
|
100
|
+
dtype=torch.float32,
|
101
|
+
device=torch.device('cpu'),
|
102
|
+
)
|
103
|
+
self.mask_cache = attn_utils.build_causal_mask_cache(
|
104
|
+
size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
|
105
|
+
)
|
106
|
+
self.config = config
|
107
|
+
|
108
|
+
@torch.inference_mode
|
109
|
+
def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
|
110
|
+
x = self.tok_embedding(idx)
|
111
|
+
cos, sin = self.rope_cache
|
112
|
+
|
113
|
+
cos = cos.index_select(0, input_pos)
|
114
|
+
sin = sin.index_select(0, input_pos)
|
115
|
+
mask = self.mask_cache.index_select(2, input_pos)
|
116
|
+
mask = mask[:, :, :, : self.config.max_seq_len]
|
117
|
+
|
118
|
+
x = self.transformer_block(x, (cos, sin), mask, input_pos)
|
119
|
+
x = self.final_norm(x)
|
120
|
+
res = self.lm_head(x)
|
121
|
+
return res
|
122
|
+
|
123
|
+
|
74
124
|
def get_model_config() -> cfg.ModelConfig:
|
75
125
|
attn_config = cfg.AttentionConfig(
|
76
126
|
num_heads=32,
|
@@ -70,7 +70,6 @@ class TestModelConversion(googletest.TestCase):
|
|
70
70
|
)
|
71
71
|
)
|
72
72
|
|
73
|
-
|
74
73
|
@googletest.skipIf(
|
75
74
|
ai_edge_config.Config.use_torch_xla,
|
76
75
|
reason="tests with custom ops are not supported on oss",
|
@@ -130,6 +129,7 @@ class TestModelConversion(googletest.TestCase):
|
|
130
129
|
)
|
131
130
|
|
132
131
|
copied_model = copy.deepcopy(pytorch_model)
|
132
|
+
copied_edge = copy.deepcopy(edge_model)
|
133
133
|
|
134
134
|
self.assertTrue(
|
135
135
|
model_coverage.compare_tflite_torch(
|
@@ -141,18 +141,15 @@ class TestModelConversion(googletest.TestCase):
|
|
141
141
|
)
|
142
142
|
)
|
143
143
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
num_valid_inputs=1,
|
154
|
-
)
|
155
|
-
)
|
144
|
+
self.assertTrue(
|
145
|
+
model_coverage.compare_tflite_torch(
|
146
|
+
copied_edge,
|
147
|
+
copied_model,
|
148
|
+
(decode_token, decode_input_pos),
|
149
|
+
signature_name="decode",
|
150
|
+
num_valid_inputs=1,
|
151
|
+
)
|
152
|
+
)
|
156
153
|
|
157
154
|
|
158
155
|
if __name__ == "__main__":
|
@@ -82,28 +82,28 @@ class TestModelConversion(googletest.TestCase):
|
|
82
82
|
model.eval()
|
83
83
|
|
84
84
|
idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
|
85
|
-
|
86
|
-
|
87
|
-
|
85
|
+
prefill_tokens = torch.full((1, 10), 0, dtype=torch.long, device="cpu")
|
86
|
+
prefill_tokens[0, :4] = idx
|
87
|
+
prefill_input_pos = torch.arange(0, 10)
|
88
88
|
|
89
|
-
edge_model = ai_edge_torch.
|
89
|
+
edge_model = ai_edge_torch.signature(
|
90
|
+
"prefill", model, (prefill_tokens, prefill_input_pos)
|
91
|
+
).convert()
|
90
92
|
edge_model.set_interpreter_builder(
|
91
93
|
self._interpreter_builder(edge_model.tflite_model())
|
92
94
|
)
|
93
95
|
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
)
|
106
|
-
)
|
96
|
+
self.assertTrue(
|
97
|
+
model_coverage.compare_tflite_torch(
|
98
|
+
edge_model,
|
99
|
+
model,
|
100
|
+
(prefill_tokens, prefill_input_pos),
|
101
|
+
signature_name="prefill",
|
102
|
+
num_valid_inputs=1,
|
103
|
+
atol=1e-2,
|
104
|
+
rtol=1e-5,
|
105
|
+
)
|
106
|
+
)
|
107
107
|
|
108
108
|
@googletest.skipIf(
|
109
109
|
ai_edge_config.Config.use_torch_xla,
|
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
|
|
25
25
|
from ai_edge_torch.generative.quantize.quant_attrs import Mode
|
26
26
|
from ai_edge_torch.quantize import quant_config
|
27
27
|
from ai_edge_torch.testing import model_coverage
|
28
|
-
from parameterized import parameterized
|
29
28
|
import torch
|
30
29
|
|
31
30
|
from absl.testing import absltest as googletest
|
31
|
+
from absl.testing import parameterized
|
32
32
|
|
33
33
|
|
34
|
-
class TestVerifyRecipes(
|
34
|
+
class TestVerifyRecipes(parameterized.TestCase):
|
35
35
|
"""Unit tests that check for model quantization recipes."""
|
36
36
|
|
37
|
-
@parameterized.
|
37
|
+
@parameterized.parameters([
|
38
38
|
(Dtype.FP32, Dtype.FP32),
|
39
39
|
(Dtype.INT8, Dtype.INT8),
|
40
40
|
(Dtype.INT8, Dtype.FP16),
|
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
|
|
52
52
|
with self.assertRaises(ValueError):
|
53
53
|
quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
|
54
54
|
|
55
|
-
@parameterized.
|
55
|
+
@parameterized.parameters([
|
56
56
|
(
|
57
57
|
Dtype.FP32,
|
58
58
|
Dtype.INT8,
|
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
|
|
88
88
|
).verify()
|
89
89
|
|
90
90
|
|
91
|
-
class TestQuantizeConvert(
|
91
|
+
class TestQuantizeConvert(parameterized.TestCase):
|
92
92
|
"""Test conversion with quantization."""
|
93
93
|
|
94
94
|
def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
|
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
|
|
105
105
|
)
|
106
106
|
)
|
107
107
|
|
108
|
-
@parameterized.
|
108
|
+
@parameterized.parameters([
|
109
109
|
(quant_recipes.full_fp16_recipe()),
|
110
110
|
(quant_recipes.full_int8_dynamic_recipe()),
|
111
111
|
(quant_recipes.full_int8_weight_only_recipe()),
|
112
112
|
(_attention_int8_dynamic_recipe()),
|
113
113
|
(_feedforward_int8_dynamic_recipe()),
|
114
114
|
])
|
115
|
-
@googletest.skipIf(
|
116
|
-
not config.Config.use_torch_xla,
|
117
|
-
reason="Not working with odml_torch at the moment.",
|
118
|
-
)
|
119
115
|
def test_quantize_convert_toy_sizes(self, quant_config):
|
120
116
|
config = toy_model.get_model_config()
|
121
117
|
pytorch_model = toy_model.ToySingleLayerModel(config)
|
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
|
|
132
128
|
"Quantized model isn't smaller than F32 model.",
|
133
129
|
)
|
134
130
|
|
131
|
+
def test_quantize_convert_toy_weight_sharing(self):
|
132
|
+
config = toy_model.get_model_config()
|
133
|
+
pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
|
134
|
+
idx = torch.unsqueeze(torch.arange(0, 100), 0)
|
135
|
+
input_pos = torch.arange(0, 100)
|
136
|
+
|
137
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe()
|
138
|
+
quantized_model = ai_edge_torch.convert(
|
139
|
+
pytorch_model, (idx, input_pos), quant_config=quant_config
|
140
|
+
)
|
141
|
+
float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
142
|
+
self.assertLess(
|
143
|
+
len(quantized_model._tflite_model),
|
144
|
+
len(float_model._tflite_model),
|
145
|
+
"Quantized model isn't smaller than F32 model.",
|
146
|
+
)
|
147
|
+
|
135
148
|
def test_quantize_convert_compare_toy(self):
|
136
149
|
self.skipTest("b/338288901")
|
137
150
|
config = toy_model_with_kv_cache.get_model_config()
|
@@ -208,7 +208,7 @@ class ModelLoader:
|
|
208
208
|
if self._file_name.endswith(".safetensors"):
|
209
209
|
return load_safetensors
|
210
210
|
|
211
|
-
if self._file_name.endswith(".bin") or self._file_name.endswith("
|
211
|
+
if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
|
212
212
|
return load_pytorch_statedict
|
213
213
|
|
214
214
|
raise ValueError("File format not supported.")
|
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
|
|
21
21
|
from ai_edge_torch._convert import conversion_utils
|
22
22
|
from ai_edge_torch._convert import signature as signature_module
|
23
23
|
from ai_edge_torch.lowertools import common_utils
|
24
|
+
from ai_edge_torch.lowertools import translate_recipe
|
24
25
|
from ai_edge_torch.odml_torch import export
|
25
26
|
from ai_edge_torch.odml_torch import export_utils
|
26
27
|
from ai_edge_torch.quantize import quant_config as qcfg
|
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
|
|
186
187
|
converter._experimental_enable_composite_direct_lowering = True
|
187
188
|
converter.model_origin_framework = "PYTORCH"
|
188
189
|
|
190
|
+
conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
|
191
|
+
if (
|
192
|
+
quant_config is not None
|
193
|
+
and quant_config._quantizer_mode
|
194
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
195
|
+
):
|
196
|
+
translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
|
197
|
+
quant_config.generative_recipe
|
198
|
+
)
|
199
|
+
|
189
200
|
conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
|
190
201
|
|
191
202
|
tflite_model = converter.convert()
|
192
203
|
|
204
|
+
if (
|
205
|
+
quant_config is not None
|
206
|
+
and quant_config._quantizer_mode
|
207
|
+
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
208
|
+
):
|
209
|
+
tflite_model = translate_recipe.quantize_model(
|
210
|
+
tflite_model, translated_recipe
|
211
|
+
)
|
212
|
+
|
193
213
|
return tflite_model
|
194
214
|
|
195
215
|
|
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
|
|
25
25
|
from ai_edge_torch import model
|
26
26
|
from ai_edge_torch._convert import conversion_utils
|
27
27
|
from ai_edge_torch._convert import signature as signature_module
|
28
|
-
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
29
28
|
from ai_edge_torch.lowertools import common_utils
|
29
|
+
from ai_edge_torch.lowertools import translate_recipe
|
30
30
|
from ai_edge_torch.quantize import quant_config as qcfg
|
31
31
|
import torch
|
32
32
|
from torch_xla import stablehlo
|
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
|
|
17
17
|
from ai_edge_torch.generative.quantize import quant_attrs
|
18
18
|
from ai_edge_torch.generative.quantize import quant_recipe
|
19
19
|
|
20
|
-
|
20
|
+
_ComputePrecision = quantizer.qtyping.ComputePrecision
|
21
|
+
_QuantGranularity = quantizer.qtyping.QuantGranularity
|
21
22
|
_OpName = quantizer.qtyping.TFLOperationName
|
22
23
|
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
|
23
24
|
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
|
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
|
|
50
51
|
return quantizer.qtyping.TensorDataType.INT
|
51
52
|
|
52
53
|
|
53
|
-
def
|
54
|
+
def _get_compute_precision_from_mode(
|
55
|
+
mode: quant_attrs.Mode,
|
56
|
+
) -> _ComputePrecision:
|
54
57
|
if mode == quant_attrs.Mode.DYNAMIC_RANGE:
|
55
|
-
return
|
58
|
+
return _ComputePrecision.INTEGER
|
56
59
|
elif mode == quant_attrs.Mode.WEIGHT_ONLY:
|
57
|
-
return
|
60
|
+
return _ComputePrecision.FLOAT
|
58
61
|
raise ValueError('Unimplemented execution mode')
|
59
62
|
|
60
63
|
|
61
|
-
def
|
64
|
+
def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
|
65
|
+
if mode == quant_attrs.Mode.DYNAMIC_RANGE:
|
66
|
+
return False
|
67
|
+
elif mode == quant_attrs.Mode.WEIGHT_ONLY:
|
68
|
+
return True
|
69
|
+
raise ValueError('Unimplemented execution mode')
|
70
|
+
|
71
|
+
|
72
|
+
def _get_granularity(
|
62
73
|
granularity: quant_attrs.Granularity,
|
63
74
|
) -> bool:
|
64
75
|
if granularity == quant_attrs.Granularity.CHANNELWISE:
|
65
|
-
return
|
66
|
-
|
67
|
-
return
|
76
|
+
return _QuantGranularity.CHANNELWISE
|
77
|
+
if granularity == quant_attrs.Granularity.NONE:
|
78
|
+
return _QuantGranularity.TENSORWISE
|
68
79
|
raise ValueError('Unimplemented granularity')
|
69
80
|
|
70
81
|
|
@@ -88,12 +99,13 @@ def _set_quant_config(
|
|
88
99
|
weight_tensor_config=_TensorQuantConfig(
|
89
100
|
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
|
90
101
|
symmetric=True,
|
91
|
-
|
92
|
-
layer_recipe.granularity
|
93
|
-
),
|
102
|
+
granularity=_get_granularity(layer_recipe.granularity),
|
94
103
|
dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
|
95
104
|
),
|
96
|
-
|
105
|
+
compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
|
106
|
+
explicit_dequantize=_get_explicit_dequant_from_mode(
|
107
|
+
layer_recipe.mode
|
108
|
+
),
|
97
109
|
),
|
98
110
|
algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
|
99
111
|
)
|
@@ -227,12 +227,9 @@ def exported_program_to_mlir(
|
|
227
227
|
exported_program: torch.export.ExportedProgram,
|
228
228
|
) -> MlirLowered:
|
229
229
|
"""Lower the exported program to MLIR."""
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
exported_program = exported_program.run_decompositions(
|
234
|
-
lowerings.decompositions()
|
235
|
-
)
|
230
|
+
exported_program = exported_program.run_decompositions(
|
231
|
+
lowerings.decompositions()
|
232
|
+
)
|
236
233
|
|
237
234
|
with export_utils.create_ir_context() as context, ir.Location.unknown():
|
238
235
|
|
@@ -35,7 +35,7 @@ jax.config.update("jax_enable_x64", True)
|
|
35
35
|
|
36
36
|
def _lower_to_ir_text(
|
37
37
|
jaxfn, args, kwargs, ir_input_names: list[str] = None
|
38
|
-
) -> str:
|
38
|
+
) -> tuple[str, list[ir.Value]]:
|
39
39
|
args = utils.tree_map_list_to_tuple(args)
|
40
40
|
kwargs = utils.tree_map_list_to_tuple(kwargs)
|
41
41
|
|
@@ -74,7 +74,9 @@ def _lower_to_ir_text(
|
|
74
74
|
x for x in pytree.tree_flatten(arg)[0] if isinstance(x, ir.Value)
|
75
75
|
]
|
76
76
|
|
77
|
-
def
|
77
|
+
def lower_wrapper(*args):
|
78
|
+
nonlocal jax_lower_static_kwargs
|
79
|
+
|
78
80
|
jaxfn_args = []
|
79
81
|
jaxfn_kwargs = jax_lower_static_kwargs.copy()
|
80
82
|
for name, arg in zip(jax_lower_argnames, args):
|
@@ -85,11 +87,7 @@ def _lower_to_ir_text(
|
|
85
87
|
|
86
88
|
return jaxfn(*jaxfn_args, **jaxfn_kwargs)
|
87
89
|
|
88
|
-
return (
|
89
|
-
jax.jit(new_lowering, static_argnames=static_argnames)
|
90
|
-
.lower(*jax_lower_args, **jax_lower_static_kwargs)
|
91
|
-
.as_text()
|
92
|
-
), ir_inputs
|
90
|
+
return jax.jit(lower_wrapper).lower(*jax_lower_args).as_text(), ir_inputs
|
93
91
|
|
94
92
|
|
95
93
|
def wrap(jaxfn: Callable[Any, Any], ir_input_names: list[str] = None):
|
@@ -52,6 +52,7 @@ class LoweringRegistry:
|
|
52
52
|
|
53
53
|
|
54
54
|
global_registry = LoweringRegistry()
|
55
|
+
global_registry.decompositions.update(torch._decomp.core_aten_decompositions())
|
55
56
|
global_registry.decompositions.update(
|
56
57
|
torch._decomp.get_decompositions([
|
57
58
|
torch.ops.aten.upsample_nearest2d,
|
@@ -70,6 +71,13 @@ global_registry.decompositions.update(
|
|
70
71
|
])
|
71
72
|
)
|
72
73
|
|
74
|
+
torch._decomp.remove_decompositions(
|
75
|
+
global_registry.decompositions,
|
76
|
+
[
|
77
|
+
torch.ops.aten.roll,
|
78
|
+
],
|
79
|
+
)
|
80
|
+
|
73
81
|
|
74
82
|
def lookup(op):
|
75
83
|
return global_registry.lookup(op)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240906
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
|
|
30
30
|
Requires-Dist: torch>=2.4.0
|
31
31
|
Requires-Dist: torch-xla>=2.4.0
|
32
32
|
Requires-Dist: tf-nightly>=2.18.0.dev20240722
|
33
|
-
Requires-Dist: ai-edge-quantizer-nightly
|
33
|
+
Requires-Dist: ai-edge-quantizer-nightly
|
34
34
|
|
35
35
|
Library that supports converting PyTorch models into a .tflite format, which can
|
36
36
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=vEc_GracKJpLkIs6M45gCFWkBMuXTjmvfvJnfXBSyrs,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -77,7 +77,7 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
|
|
77
77
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
|
78
78
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
|
79
79
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
80
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
80
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
|
81
81
|
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
|
82
82
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
|
83
83
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
|
|
106
106
|
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
|
107
107
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
108
108
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
109
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
110
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
|
111
109
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
112
110
|
ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
|
113
111
|
ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
|
114
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
115
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
116
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
112
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
|
113
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
|
114
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
|
117
115
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
118
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
116
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
|
119
117
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
|
120
118
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
|
121
119
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
|
|
128
126
|
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
129
127
|
ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
|
130
128
|
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
131
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
129
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
|
132
130
|
ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
|
133
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256
|
131
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
|
132
|
+
ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
|
134
133
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
135
134
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
136
135
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
137
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
136
|
+
ai_edge_torch/odml_torch/export.py,sha256=_n43AlaTLvAK6r1szs47gSBqp-x19ZNCNtyFIWzuE4Q,10322
|
138
137
|
ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
|
139
138
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
|
140
139
|
ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
|
@@ -144,7 +143,7 @@ ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=9ag6-WWRG50rPCtIV7OpIokEKu
|
|
144
143
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=1xCXOs3-9UcsOyLFH0uyQwLu7c06iYFTo0NQ7Ckbl2I,1465
|
145
144
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
146
145
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
147
|
-
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=
|
146
|
+
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
148
147
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
149
148
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
|
150
149
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
|
@@ -152,7 +151,7 @@ ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_
|
|
152
151
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
|
153
152
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
|
154
153
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
155
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
154
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
|
156
155
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
157
156
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
158
157
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
162
161
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
163
162
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
164
163
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
168
|
-
ai_edge_torch_nightly-0.3.0.
|
169
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
165
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/METADATA,sha256=u4yKvulxsV9xZmKSKnNO6L_FE8P_Iy96IZ0UL_voxAE,1859
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240906.dist-info/RECORD,,
|
@@ -1,14 +0,0 @@
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
-
#
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
-
# you may not use this file except in compliance with the License.
|
5
|
-
# You may obtain a copy of the License at
|
6
|
-
#
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
-
#
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
-
# See the License for the specific language governing permissions and
|
13
|
-
# limitations under the License.
|
14
|
-
# ==============================================================================
|
File without changes
|
File without changes
|