ai-edge-torch-nightly 0.3.0.dev20240817__py3-none-any.whl → 0.3.0.dev20240823__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/signature.py +2 -36
- ai_edge_torch/_convert/test/test_convert.py +32 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +2 -2
- ai_edge_torch/generative/examples/gemma/gemma2.py +1 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +6 -4
- ai_edge_torch/generative/utilities/loader.py +3 -2
- ai_edge_torch/lowertools/__init__.py +1 -0
- ai_edge_torch/lowertools/common_utils.py +53 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +22 -5
- ai_edge_torch/lowertools/torch_xla_utils.py +24 -13
- ai_edge_torch/model.py +13 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240823.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240823.dist-info}/RECORD +17 -17
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240823.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240823.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240817.dist-info → ai_edge_torch_nightly-0.3.0.dev20240823.dist-info}/top_level.txt +0 -0
|
@@ -16,6 +16,7 @@
|
|
|
16
16
|
import dataclasses
|
|
17
17
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
18
18
|
|
|
19
|
+
from ai_edge_torch import lowertools
|
|
19
20
|
import torch
|
|
20
21
|
import torch.utils._pytree as pytree
|
|
21
22
|
|
|
@@ -53,47 +54,12 @@ class Signature:
|
|
|
53
54
|
for i in range(args_spec.num_leaves):
|
|
54
55
|
names.append(f"args_{i}")
|
|
55
56
|
|
|
56
|
-
kwargs_names =
|
|
57
|
+
kwargs_names = lowertools.flat_dict_names(
|
|
57
58
|
kwargs_spec.children_specs, kwargs_spec.context
|
|
58
59
|
)
|
|
59
60
|
names.extend(kwargs_names)
|
|
60
61
|
return names
|
|
61
62
|
|
|
62
|
-
def _flat_kwarg_names(self, specs, context) -> List[str]:
|
|
63
|
-
flat_names = []
|
|
64
|
-
if context is None:
|
|
65
|
-
for i, spec in enumerate(specs):
|
|
66
|
-
if spec.children_specs:
|
|
67
|
-
flat_names.extend([
|
|
68
|
-
f"{i}_{name}"
|
|
69
|
-
for name in self._flat_kwarg_names(
|
|
70
|
-
spec.children_specs, spec.context
|
|
71
|
-
)
|
|
72
|
-
])
|
|
73
|
-
else:
|
|
74
|
-
flat_names.append(f"{i}")
|
|
75
|
-
else:
|
|
76
|
-
flat_ctx = self._flatten_list(context)
|
|
77
|
-
for prefix, spec in zip(flat_ctx, specs):
|
|
78
|
-
leaf_flat_names = self._flat_kwarg_names(
|
|
79
|
-
spec.children_specs, spec.context
|
|
80
|
-
)
|
|
81
|
-
if leaf_flat_names:
|
|
82
|
-
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
83
|
-
else:
|
|
84
|
-
flat_names.append(prefix)
|
|
85
|
-
|
|
86
|
-
return flat_names
|
|
87
|
-
|
|
88
|
-
def _flatten_list(self, l: List) -> List:
|
|
89
|
-
flattened = []
|
|
90
|
-
for item in l:
|
|
91
|
-
if isinstance(item, list):
|
|
92
|
-
flattened.extend(self._flatten_list(item))
|
|
93
|
-
else:
|
|
94
|
-
flattened.append(item)
|
|
95
|
-
return flattened
|
|
96
|
-
|
|
97
63
|
@property
|
|
98
64
|
def flat_args(self) -> tuple[Any]:
|
|
99
65
|
args, kwargs = self._normalized_sample_args_kwargs
|
|
@@ -174,7 +174,7 @@ class TestConvert(googletest.TestCase):
|
|
|
174
174
|
self.assertTrue(result)
|
|
175
175
|
|
|
176
176
|
def test_12_outputs_model(self):
|
|
177
|
-
"""Tests conversion of a model that returns
|
|
177
|
+
"""Tests conversion of a model that returns more than 10 outputs."""
|
|
178
178
|
|
|
179
179
|
class BasicAddModelWithMultipleOutputs(torch.nn.Module):
|
|
180
180
|
"""A model that returns multiple outputs."""
|
|
@@ -421,6 +421,37 @@ class TestConvert(googletest.TestCase):
|
|
|
421
421
|
SampleModel(), args, kwargs, flat_inputs
|
|
422
422
|
)
|
|
423
423
|
|
|
424
|
+
def test_convert_model_non_flat_output_dict(self):
|
|
425
|
+
"""Test converting a model with non-flat output structure."""
|
|
426
|
+
|
|
427
|
+
class SampleModel(torch.nn.Module):
|
|
428
|
+
|
|
429
|
+
def forward(self, x, y, z):
|
|
430
|
+
return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
|
|
431
|
+
|
|
432
|
+
args = (torch.randn(10, 10), torch.randn(10, 10), torch.randn(10, 10))
|
|
433
|
+
kwargs = dict()
|
|
434
|
+
flat_inputs = {
|
|
435
|
+
"args_0": args[0].numpy(),
|
|
436
|
+
"args_1": args[1].numpy(),
|
|
437
|
+
"args_2": args[2].numpy(),
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
edge_model = ai_edge_torch.convert(SampleModel().eval(), args, kwargs)
|
|
441
|
+
edge_output = edge_model(**flat_inputs)
|
|
442
|
+
np.testing.assert_almost_equal(edge_output["x"], args[0])
|
|
443
|
+
np.testing.assert_almost_equal(edge_output["y_data_1"], args[1])
|
|
444
|
+
np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
|
|
445
|
+
np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])
|
|
446
|
+
|
|
447
|
+
interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
|
|
448
|
+
runner = interpreter.get_signature_runner("serving_default")
|
|
449
|
+
output_details = runner.get_output_details()
|
|
450
|
+
self.assertIn("x", output_details.keys())
|
|
451
|
+
self.assertIn("y_data_1", output_details.keys())
|
|
452
|
+
self.assertIn("y_data_2_0", output_details.keys())
|
|
453
|
+
self.assertIn("y_data_2_1", output_details.keys())
|
|
454
|
+
|
|
424
455
|
def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
|
|
425
456
|
model.eval()
|
|
426
457
|
edge_model = ai_edge_torch.convert(model, args, kwargs)
|
|
@@ -94,8 +94,8 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
94
94
|
signature_name = "large_input"
|
|
95
95
|
|
|
96
96
|
edge_model = ai_edge_torch.signature(
|
|
97
|
-
signature_name, torch_module,
|
|
98
|
-
).convert(torch_module,
|
|
97
|
+
signature_name, torch_module, large_args
|
|
98
|
+
).convert(torch_module, args)
|
|
99
99
|
|
|
100
100
|
self.assertTrue(
|
|
101
101
|
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
|
|
@@ -203,7 +203,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
|
|
|
203
203
|
final_norm_config=norm_config,
|
|
204
204
|
parallel_residual=False,
|
|
205
205
|
lm_head_use_bias=False,
|
|
206
|
-
enable_hlfb=
|
|
206
|
+
enable_hlfb=True,
|
|
207
207
|
final_logit_softcap=30.0,
|
|
208
208
|
)
|
|
209
209
|
return config
|
|
@@ -242,7 +242,6 @@ def define_and_run_2b() -> None:
|
|
|
242
242
|
out = model.forward(tokens, input_pos)
|
|
243
243
|
out_final = out[0, 8, :]
|
|
244
244
|
assert torch.allclose(gemma2_goldens, out_final, atol=1e-04)
|
|
245
|
-
print(out)
|
|
246
245
|
|
|
247
246
|
|
|
248
247
|
if __name__ == "__main__":
|
|
@@ -99,14 +99,16 @@ def scaled_dot_product_attention_with_hlfb(
|
|
|
99
99
|
The output tensor of scaled_dot_product_attention.
|
|
100
100
|
"""
|
|
101
101
|
|
|
102
|
-
if softcap is not None:
|
|
103
|
-
raise NotImplementedError("SDPA with HLFB not available with softcap.")
|
|
104
|
-
|
|
105
102
|
if scale is None:
|
|
106
103
|
scale = 1.0 / math.sqrt(head_size)
|
|
107
104
|
|
|
105
|
+
attrs = {"scale": scale}
|
|
106
|
+
|
|
107
|
+
if softcap is not None:
|
|
108
|
+
attrs["logit_cap"] = softcap
|
|
109
|
+
|
|
108
110
|
builder = StableHLOCompositeBuilder(
|
|
109
|
-
name="odml.scaled_dot_product_attention", attr=
|
|
111
|
+
name="odml.scaled_dot_product_attention", attr=attrs
|
|
110
112
|
)
|
|
111
113
|
q, k, v, mask = builder.mark_inputs(q, k, v, mask)
|
|
112
114
|
|
|
@@ -72,7 +72,7 @@ def load_pytorch_statedict(full_path: str):
|
|
|
72
72
|
patterns = []
|
|
73
73
|
if os.path.isdir(full_path):
|
|
74
74
|
patterns.append(os.path.join(full_path, "*.bin"))
|
|
75
|
-
patterns.append(os.path.join(full_path, "
|
|
75
|
+
patterns.append(os.path.join(full_path, "*pt"))
|
|
76
76
|
else:
|
|
77
77
|
patterns.append(full_path)
|
|
78
78
|
for pattern in patterns:
|
|
@@ -149,6 +149,7 @@ class ModelLoader:
|
|
|
149
149
|
enabled.
|
|
150
150
|
"""
|
|
151
151
|
state = self._loader(self._file_name)
|
|
152
|
+
state = state["model_state_dict"] if "model_state_dict" in state else state
|
|
152
153
|
converted_state = dict()
|
|
153
154
|
if self._names.embedding is not None:
|
|
154
155
|
converted_state["tok_embedding.weight"] = state.pop(
|
|
@@ -200,7 +201,7 @@ class ModelLoader:
|
|
|
200
201
|
if glob.glob(os.path.join(self._file_name, "*.safetensors")):
|
|
201
202
|
return load_safetensors
|
|
202
203
|
if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
|
|
203
|
-
os.path.join(self._file_name, "
|
|
204
|
+
os.path.join(self._file_name, "*pt")
|
|
204
205
|
):
|
|
205
206
|
return load_pytorch_statedict
|
|
206
207
|
|
|
@@ -14,10 +14,63 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
import logging
|
|
17
|
+
from typing import List
|
|
17
18
|
|
|
18
19
|
from ai_edge_torch._convert import signature as signature_module
|
|
19
20
|
import tensorflow as tf
|
|
20
21
|
import torch
|
|
22
|
+
import torch.utils._pytree as pytree
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _flatten_list(l: List) -> List:
|
|
26
|
+
flattened = []
|
|
27
|
+
for item in l:
|
|
28
|
+
if isinstance(item, list):
|
|
29
|
+
flattened.extend(_flatten_list(item))
|
|
30
|
+
else:
|
|
31
|
+
flattened.append(item)
|
|
32
|
+
return flattened
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def flat_dict_names(
|
|
36
|
+
tree_spec: pytree.TreeSpec, context: pytree.Context
|
|
37
|
+
) -> List[str]:
|
|
38
|
+
"""Given a TreeSpec, this produces a list of names for the leaves.
|
|
39
|
+
|
|
40
|
+
The list of names embeddeds the structure of the tree_spec. A nesting level is
|
|
41
|
+
indicated by an `_` and elements in a list are indicated by `_<index>`.
|
|
42
|
+
|
|
43
|
+
TODO b/361601485: The flattening of names is not collision-free and needs to
|
|
44
|
+
be revised.
|
|
45
|
+
|
|
46
|
+
Args:
|
|
47
|
+
tree_spec: The TreeSpec to extract the names from.
|
|
48
|
+
context: The context used to check if the provided spec belongs to a
|
|
49
|
+
dictionary or a list.
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
A list of flattened names.
|
|
53
|
+
"""
|
|
54
|
+
flat_names = []
|
|
55
|
+
if context is None:
|
|
56
|
+
for i, spec in enumerate(tree_spec):
|
|
57
|
+
if spec.children_specs:
|
|
58
|
+
flat_names.extend([
|
|
59
|
+
f"{i}_{name}"
|
|
60
|
+
for name in flat_dict_names(spec.children_specs, spec.context)
|
|
61
|
+
])
|
|
62
|
+
else:
|
|
63
|
+
flat_names.append(f"{i}")
|
|
64
|
+
else:
|
|
65
|
+
flat_ctx = _flatten_list(context)
|
|
66
|
+
for prefix, spec in zip(flat_ctx, tree_spec):
|
|
67
|
+
leaf_flat_names = flat_dict_names(spec.children_specs, spec.context)
|
|
68
|
+
if leaf_flat_names:
|
|
69
|
+
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
70
|
+
else:
|
|
71
|
+
flat_names.append(prefix)
|
|
72
|
+
|
|
73
|
+
return flat_names
|
|
21
74
|
|
|
22
75
|
|
|
23
76
|
def _torch_to_tf_variable(torch_tensor: torch.Tensor):
|
|
@@ -38,6 +38,7 @@ class MergedBundle:
|
|
|
38
38
|
"""A bundle of MlirLowered that has been merged."""
|
|
39
39
|
|
|
40
40
|
bundles: list[odml_torch.export.MlirLowered]
|
|
41
|
+
exported_programs: list[torch.export.ExportedProgram]
|
|
41
42
|
deduped_tf_vars: list[tf.Variable]
|
|
42
43
|
|
|
43
44
|
|
|
@@ -74,12 +75,16 @@ def _extract_call_args(
|
|
|
74
75
|
return call_args
|
|
75
76
|
|
|
76
77
|
|
|
77
|
-
def _wrap_as_tf_func(
|
|
78
|
+
def _wrap_as_tf_func(
|
|
79
|
+
bundle: export.MlirLowered,
|
|
80
|
+
tf_state_dict: Dict[str, tf.Variable],
|
|
81
|
+
exported_program: torch.export.ExportedProgram,
|
|
82
|
+
):
|
|
78
83
|
def inner(*args):
|
|
79
84
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
|
|
80
85
|
s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
|
|
81
86
|
call_args = _extract_call_args(bundle, args, tf_state_dict)
|
|
82
|
-
|
|
87
|
+
call_module_return = tfxla.call_module(
|
|
83
88
|
tuple(call_args),
|
|
84
89
|
version=5,
|
|
85
90
|
Tout=t_outs, # dtype information
|
|
@@ -87,6 +92,14 @@ def _wrap_as_tf_func(bundle, tf_state_dict):
|
|
|
87
92
|
function_list=[],
|
|
88
93
|
module=bundle.module_bytecode,
|
|
89
94
|
)
|
|
95
|
+
spec = exported_program.call_spec.out_spec
|
|
96
|
+
|
|
97
|
+
# The module returning a flat array.
|
|
98
|
+
if not spec.context:
|
|
99
|
+
return call_module_return
|
|
100
|
+
|
|
101
|
+
flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
|
|
102
|
+
return {name: value for name, value in zip(flat_names, call_module_return)}
|
|
90
103
|
|
|
91
104
|
return inner
|
|
92
105
|
|
|
@@ -128,8 +141,10 @@ def merged_bundle_to_tfl_model(
|
|
|
128
141
|
for bundle, sig in zip(merged_bundle.bundles, signatures)
|
|
129
142
|
]
|
|
130
143
|
tf_functions = [
|
|
131
|
-
_wrap_as_tf_func(bundle, tf_state_dict)
|
|
132
|
-
for bundle in
|
|
144
|
+
_wrap_as_tf_func(bundle, tf_state_dict, ep)
|
|
145
|
+
for bundle, ep in zip(
|
|
146
|
+
merged_bundle.bundles, merged_bundle.exported_programs
|
|
147
|
+
)
|
|
133
148
|
]
|
|
134
149
|
|
|
135
150
|
tf_module = tf.Module()
|
|
@@ -202,7 +217,9 @@ def merge_mlir_bundles(
|
|
|
202
217
|
)
|
|
203
218
|
|
|
204
219
|
merged_bundle = MergedBundle(
|
|
205
|
-
bundles=bundles.copy(),
|
|
220
|
+
bundles=bundles.copy(),
|
|
221
|
+
exported_programs=exported_programs,
|
|
222
|
+
deduped_tf_vars=deduped_vars,
|
|
206
223
|
)
|
|
207
224
|
for bundle, signature in zip(merged_bundle.bundles, signatures):
|
|
208
225
|
bundle.state_dict = state_dict
|
|
@@ -51,6 +51,7 @@ MlirBundle = stablehlo.StableHLOModelBundle
|
|
|
51
51
|
class MergedBundle:
|
|
52
52
|
|
|
53
53
|
bundle: stablehlo.StableHLOModelBundle
|
|
54
|
+
exported_programs: list[torch.export.ExportedProgram]
|
|
54
55
|
deduped_tf_vars: list[tf.Variable]
|
|
55
56
|
|
|
56
57
|
|
|
@@ -58,9 +59,9 @@ def exported_program_to_mlir(
|
|
|
58
59
|
exported_program: torch.export.ExportedProgram,
|
|
59
60
|
sample_args: tuple[torch.Tensor],
|
|
60
61
|
) -> stablehlo.StableHLOModelBundle:
|
|
61
|
-
# Setting export_weights to False here so that pytorch/xla avoids copying the
|
|
62
|
-
# to a numpy array which would lead to memory bloat. This means that
|
|
63
|
-
# in the returned bundle is going to be empty.
|
|
62
|
+
# Setting export_weights to False here so that pytorch/xla avoids copying the
|
|
63
|
+
# weights to a numpy array which would lead to memory bloat. This means that
|
|
64
|
+
# the state_dict in the returned bundle is going to be empty.
|
|
64
65
|
return stablehlo.exported_program_to_stablehlo(
|
|
65
66
|
exported_program,
|
|
66
67
|
stablehlo.StableHLOExportOptions(
|
|
@@ -96,7 +97,9 @@ def merge_mlir_bundles(
|
|
|
96
97
|
bundle.additional_constants
|
|
97
98
|
)
|
|
98
99
|
return MergedBundle(
|
|
99
|
-
bundle=new_shlo_model_bundle,
|
|
100
|
+
bundle=new_shlo_model_bundle,
|
|
101
|
+
exported_programs=exported_programs,
|
|
102
|
+
deduped_tf_vars=deduped_tf_vars,
|
|
100
103
|
)
|
|
101
104
|
|
|
102
105
|
|
|
@@ -108,7 +111,9 @@ def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
|
|
|
108
111
|
|
|
109
112
|
|
|
110
113
|
def _wrap_as_tf_func(
|
|
111
|
-
func: stablehlo.StableHLOFunc,
|
|
114
|
+
func: stablehlo.StableHLOFunc,
|
|
115
|
+
bundle: stablehlo.StableHLOModelBundle,
|
|
116
|
+
exported_program: torch.export.ExportedProgram,
|
|
112
117
|
):
|
|
113
118
|
def inner(*args):
|
|
114
119
|
type_info = [sig.dtype for sig in func.meta.output_signature]
|
|
@@ -116,7 +121,7 @@ def _wrap_as_tf_func(
|
|
|
116
121
|
_get_shape_with_dynamic(sig) for sig in func.meta.output_signature
|
|
117
122
|
]
|
|
118
123
|
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
|
|
119
|
-
|
|
124
|
+
call_module_return = tfxla.call_module(
|
|
120
125
|
tuple(call_args),
|
|
121
126
|
version=5,
|
|
122
127
|
Tout=type_info,
|
|
@@ -124,15 +129,16 @@ def _wrap_as_tf_func(
|
|
|
124
129
|
function_list=[],
|
|
125
130
|
module=func.bytecode,
|
|
126
131
|
)
|
|
132
|
+
spec = exported_program.call_spec.out_spec
|
|
127
133
|
|
|
128
|
-
|
|
134
|
+
# The module returning a flat array.
|
|
135
|
+
if not spec.context:
|
|
136
|
+
return call_module_return
|
|
129
137
|
|
|
138
|
+
flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
|
|
139
|
+
return {name: value for name, value in zip(flat_names, call_module_return)}
|
|
130
140
|
|
|
131
|
-
|
|
132
|
-
bundle: stablehlo.StableHLOModelBundle = None,
|
|
133
|
-
):
|
|
134
|
-
bundle = bundle if bundle is None else bundle
|
|
135
|
-
return [_wrap_as_tf_func(func, bundle) for func in bundle.stablehlo_funcs]
|
|
141
|
+
return inner
|
|
136
142
|
|
|
137
143
|
|
|
138
144
|
def _make_tf_signature(
|
|
@@ -205,7 +211,12 @@ def merged_bundle_to_tfl_model(
|
|
|
205
211
|
for func, sig in zip(shlo_bundle.stablehlo_funcs, signatures)
|
|
206
212
|
)
|
|
207
213
|
|
|
208
|
-
tf_functions =
|
|
214
|
+
tf_functions = [
|
|
215
|
+
_wrap_as_tf_func(func, shlo_bundle, ep)
|
|
216
|
+
for func, ep in zip(
|
|
217
|
+
shlo_bundle.stablehlo_funcs, merged_bundle.exported_programs
|
|
218
|
+
)
|
|
219
|
+
]
|
|
209
220
|
|
|
210
221
|
tf_module.f = []
|
|
211
222
|
for tf_sig, func in zip(tf_signatures, tf_functions):
|
ai_edge_torch/model.py
CHANGED
|
@@ -21,6 +21,7 @@ PyTorch models can be converted to this representation through
|
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
23
|
import abc
|
|
24
|
+
import re
|
|
24
25
|
|
|
25
26
|
import numpy.typing as npt
|
|
26
27
|
import tensorflow as tf
|
|
@@ -115,11 +116,18 @@ class TfLiteModel(Model):
|
|
|
115
116
|
inputs = {**inputs, **kwargs}
|
|
116
117
|
outputs = runner(**inputs)
|
|
117
118
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
)
|
|
119
|
+
# When attempting to run a model, check if all the output tensors are named
|
|
120
|
+
# output_<number>. If so, assume the pytorch model returned a tuple and not
|
|
121
|
+
# a dictionary.
|
|
122
|
+
output_heuristic = lambda key: bool(re.search(r'output_\d+', key))
|
|
123
|
+
if all(output_heuristic(key) for key in outputs.keys()):
|
|
124
|
+
return (
|
|
125
|
+
outputs['output_0']
|
|
126
|
+
if len(outputs) == 1
|
|
127
|
+
else [outputs[f'output_{idx}'] for idx in range(len(outputs))]
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
return outputs
|
|
123
131
|
|
|
124
132
|
def export(self, path: str) -> None:
|
|
125
133
|
"""Serializes the edge model to disk.
|
ai_edge_torch/version.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.3.0.
|
|
3
|
+
Version: 0.3.0.dev20240823
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
|
-
ai_edge_torch/model.py,sha256=
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
4
|
+
ai_edge_torch/model.py,sha256=7tox6sdFIlCYPLDYpjFcD8cPTSivURCL_VV6-Dt5Sfc,4910
|
|
5
|
+
ai_edge_torch/version.py,sha256=a63GrjqX4sRjk0WbC_0gGhT-ax_TLEi4iCEw0Iys7bw,706
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=ezmaATnQi7NWDo37LUb-hEXtZSmT7_AT6vqXC6Fcq1o,8615
|
|
10
|
-
ai_edge_torch/_convert/signature.py,sha256=
|
|
10
|
+
ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
|
|
11
11
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
|
12
12
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=hVuNftOcZIpwkUcPab44mhFmi9Z1f1REV5o3j39Sf-w,2818
|
|
13
13
|
ai_edge_torch/_convert/fx_passes/_pass_base.py,sha256=WVYZuocpygHAzk9u1GNoGowAIOHTlJXyA_NklmYkRms,1672
|
|
@@ -26,9 +26,9 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
29
|
-
ai_edge_torch/_convert/test/test_convert.py,sha256=
|
|
29
|
+
ai_edge_torch/_convert/test/test_convert.py,sha256=tvj7fWHHmA9ddtcu-Fp3lJ6emaAQMrtK9wCG0cjgRAo,14413
|
|
30
30
|
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=CBiOqq-m7QT2ggBI1jBl9MkTIT5d0nK1tA0BUga0LGs,7994
|
|
31
|
-
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=
|
|
31
|
+
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=0xIkahEU26Qx9GGn6Dm05ObIqJvsCdh692dREcaHEdE,4725
|
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=jLAmyHw5llT2ff8qA8mem3eVN57e_o5EpBnW72ZtP2I,3026
|
|
33
33
|
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
|
34
34
|
ai_edge_torch/debug/culprit.py,sha256=7UYVpVWpiCXbMAyThVtHt_kc_poT7sCTh5UUPvcycgk,14832
|
|
@@ -53,7 +53,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
|
|
|
53
53
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
|
|
54
54
|
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
|
|
55
55
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=cCki-0cKvmGxK4Md6dRNdPDWZUyhkJUI854OCTFf3h0,6262
|
|
56
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
|
56
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=q9Zil66EvRKrSpLVQHxKHu_8NL0HAgY2FbtThoTZVUY,8226
|
|
57
57
|
ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
58
58
|
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
|
|
59
59
|
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=C_kFYsPrEQ9GJCnc6h-jh8B5qQryvEpI6O6t4FBxg1I,5858
|
|
@@ -94,7 +94,7 @@ ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lf
|
|
|
94
94
|
ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
|
|
95
95
|
ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
|
|
96
96
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
|
97
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
|
97
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
|
98
98
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
99
99
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=4a0wh0co8Avz1wvxS3XqsgrgL5G-X1GSARI5Rj3L-xg,26995
|
|
100
100
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
|
@@ -114,7 +114,7 @@ ai_edge_torch/generative/test/test_loader.py,sha256=1ZqAq0HY5uIioumsReOVIsbGBx0W
|
|
|
114
114
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=52ciFy_Qol2Xuym6P6EqdL29oai35LSWGvsUwyEdFTo,8477
|
|
115
115
|
ai_edge_torch/generative/test/test_quantize.py,sha256=3SmJm7Kq98gAneU6IGwwJrJYCVH1qwWR6oUxPfb6qiI,5346
|
|
116
116
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
117
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
|
117
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
|
|
118
118
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
|
|
119
119
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
|
|
120
120
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
|
@@ -124,12 +124,12 @@ ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=uiYRfzD1T8deCEAGfdAFusRbI41m14
|
|
|
124
124
|
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
125
125
|
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=e53YNSO2w7Sd9Y717jAr6WKjnXq34Tx_52hXRGtGs3A,4833
|
|
126
126
|
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=7Qbba7GJCBc-J1TUwWIvrpBK0Hwza9nift7sKpW2YVE,8449
|
|
127
|
-
ai_edge_torch/lowertools/__init__.py,sha256=
|
|
127
|
+
ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
|
|
128
128
|
ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
|
|
129
|
-
ai_edge_torch/lowertools/common_utils.py,sha256=
|
|
130
|
-
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=
|
|
129
|
+
ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
|
|
130
|
+
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=EA2ylE4abCyC1G3lYuNouPGx0lmwtIZe7c42dtX0-3g,7146
|
|
131
131
|
ai_edge_torch/lowertools/test_utils.py,sha256=vsjaX3Ix2U1163jVUNSJgK9io2WNUtJjRvNFE9DrqF4,1932
|
|
132
|
-
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-
|
|
132
|
+
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-SRm9YNsIGsaVd5Cyp2PP-tdLBJH8EDoMFAa2y89a1w,9043
|
|
133
133
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
134
134
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
|
135
135
|
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=eARD1LxLi5m7Z0n_psAkeX_AtUp4fNkE--oECBfivv4,36208
|
|
@@ -137,8 +137,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
|
137
137
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
138
138
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
139
139
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
140
|
-
ai_edge_torch_nightly-0.3.0.
|
|
141
|
-
ai_edge_torch_nightly-0.3.0.
|
|
142
|
-
ai_edge_torch_nightly-0.3.0.
|
|
143
|
-
ai_edge_torch_nightly-0.3.0.
|
|
144
|
-
ai_edge_torch_nightly-0.3.0.
|
|
140
|
+
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
141
|
+
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/METADATA,sha256=OPYTq5RQCL2lvIFeBKOwbFusi4rq_Qo2ytgn_JQTVb0,1885
|
|
142
|
+
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
143
|
+
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
144
|
+
ai_edge_torch_nightly-0.3.0.dev20240823.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|