ai-edge-torch-nightly 0.2.0.dev20240718__py3-none-any.whl → 0.2.0.dev20240719__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/conversion_utils.py +39 -18
- ai_edge_torch/convert/test/test_convert.py +106 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240719.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240719.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240719.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240719.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240718.dist-info → ai_edge_torch_nightly-0.2.0.dev20240719.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,7 @@ import gc
|
|
|
20
20
|
import itertools
|
|
21
21
|
import logging
|
|
22
22
|
import tempfile
|
|
23
|
-
from typing import Any, Dict, Optional, Tuple, Union
|
|
23
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24
24
|
|
|
25
25
|
import torch
|
|
26
26
|
import torch.utils._pytree as pytree
|
|
@@ -79,28 +79,49 @@ class Signature:
|
|
|
79
79
|
for i in range(args_spec.num_leaves):
|
|
80
80
|
names.append(f"args_{i}")
|
|
81
81
|
|
|
82
|
-
|
|
83
|
-
kwargs_spec.context
|
|
84
|
-
if kwargs_spec.type is not collections.defaultdict
|
|
85
|
-
# ignore mismatch of `default_factory` for defaultdict
|
|
86
|
-
else kwargs_spec.context[1]
|
|
82
|
+
kwargs_names = self._flat_kwarg_names(
|
|
83
|
+
kwargs_spec.children_specs, kwargs_spec.context
|
|
87
84
|
)
|
|
85
|
+
names.extend(kwargs_names)
|
|
86
|
+
return names
|
|
88
87
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
88
|
+
def _flat_kwarg_names(self, specs, context) -> List[str]:
|
|
89
|
+
flat_names = []
|
|
90
|
+
if context is None:
|
|
91
|
+
for i, spec in enumerate(specs):
|
|
92
|
+
if spec.children_specs:
|
|
93
|
+
flat_names.extend(
|
|
94
|
+
[
|
|
95
|
+
f"{i}_{name}"
|
|
96
|
+
for name in self._flat_kwarg_names(spec.children_specs, spec.context)
|
|
97
|
+
]
|
|
98
|
+
)
|
|
99
|
+
else:
|
|
100
|
+
flat_names.append(f"{i}")
|
|
101
|
+
else:
|
|
102
|
+
flat_ctx = self._flatten_list(context)
|
|
103
|
+
for prefix, spec in zip(flat_ctx, specs):
|
|
104
|
+
leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
|
|
105
|
+
if leaf_flat_names:
|
|
106
|
+
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
107
|
+
else:
|
|
108
|
+
flat_names.append(prefix)
|
|
109
|
+
|
|
110
|
+
return flat_names
|
|
111
|
+
|
|
112
|
+
def _flatten_list(self, l: List) -> List:
|
|
113
|
+
flattened = []
|
|
114
|
+
for item in l:
|
|
115
|
+
if isinstance(item, list):
|
|
116
|
+
flattened.extend(self._flatten_list(item))
|
|
92
117
|
else:
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
# tensor containers as inputs.
|
|
96
|
-
# TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
|
|
97
|
-
for i in range(value_spec.num_leaves):
|
|
98
|
-
names.append(f"{name}_{i}")
|
|
99
|
-
return names
|
|
118
|
+
flattened.append(item)
|
|
119
|
+
return flattened
|
|
100
120
|
|
|
101
121
|
@property
|
|
102
|
-
def flat_args(self) -> tuple[
|
|
103
|
-
|
|
122
|
+
def flat_args(self) -> tuple[Any]:
|
|
123
|
+
args, kwargs = self._normalized_sample_args_kwargs
|
|
124
|
+
return tuple([*args, *kwargs.values()])
|
|
104
125
|
|
|
105
126
|
|
|
106
127
|
def exported_program_to_stablehlo_bundle(
|
|
@@ -14,10 +14,14 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
|
|
16
16
|
|
|
17
|
+
from dataclasses import dataclass
|
|
17
18
|
import os
|
|
18
19
|
import tempfile
|
|
20
|
+
from typing import Tuple
|
|
19
21
|
import unittest
|
|
20
22
|
|
|
23
|
+
import numpy as np
|
|
24
|
+
import tensorflow as tf
|
|
21
25
|
import torch
|
|
22
26
|
import torchvision
|
|
23
27
|
|
|
@@ -26,6 +30,15 @@ from ai_edge_torch.convert import conversion_utils as cutils
|
|
|
26
30
|
from ai_edge_torch.testing import model_coverage
|
|
27
31
|
|
|
28
32
|
|
|
33
|
+
@dataclass
|
|
34
|
+
class TestContainer1:
|
|
35
|
+
data_1: torch.Tensor
|
|
36
|
+
data_2: Tuple[torch.Tensor, torch.Tensor]
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
torch.export.register_dataclass(TestContainer1, serialized_type_name="TestContainer1")
|
|
40
|
+
|
|
41
|
+
|
|
29
42
|
class TestConvert(unittest.TestCase):
|
|
30
43
|
"""Tests conversion of various modules."""
|
|
31
44
|
|
|
@@ -306,6 +319,99 @@ class TestConvert(unittest.TestCase):
|
|
|
306
319
|
model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
|
|
307
320
|
)
|
|
308
321
|
|
|
322
|
+
def test_convert_model_with_args_nested_kwargs_1(self):
|
|
323
|
+
"""
|
|
324
|
+
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
325
|
+
"""
|
|
326
|
+
|
|
327
|
+
class SampleModel(torch.nn.Module):
|
|
328
|
+
|
|
329
|
+
def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
|
|
330
|
+
return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
|
|
331
|
+
|
|
332
|
+
args = (torch.randn(10, 10),)
|
|
333
|
+
kwargs = dict(
|
|
334
|
+
y=torch.randn(10, 10),
|
|
335
|
+
z=TestContainer1(
|
|
336
|
+
data_1=torch.randn(10, 10),
|
|
337
|
+
data_2=(torch.randn(10, 10), torch.randn(10, 10)),
|
|
338
|
+
),
|
|
339
|
+
)
|
|
340
|
+
flat_inputs = {
|
|
341
|
+
"args_0": args[0].numpy(),
|
|
342
|
+
"y": kwargs["y"].numpy(),
|
|
343
|
+
"z_data_1": kwargs["z"].data_1.numpy(),
|
|
344
|
+
"z_data_2_0": kwargs["z"].data_2[0].numpy(),
|
|
345
|
+
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
346
|
+
}
|
|
347
|
+
self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
348
|
+
|
|
349
|
+
def test_convert_model_with_args_nested_kwargs_2(self):
|
|
350
|
+
"""
|
|
351
|
+
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
352
|
+
"""
|
|
353
|
+
|
|
354
|
+
class SampleModel(torch.nn.Module):
|
|
355
|
+
|
|
356
|
+
def forward(self, x, y, z):
|
|
357
|
+
return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
|
|
358
|
+
|
|
359
|
+
args = (torch.randn(10, 10),)
|
|
360
|
+
kwargs = dict(
|
|
361
|
+
y=torch.randn(10, 10),
|
|
362
|
+
z=TestContainer1(
|
|
363
|
+
data_1=torch.randn(10, 10),
|
|
364
|
+
data_2=[(torch.randn(10, 10),), torch.randn(10, 10)],
|
|
365
|
+
),
|
|
366
|
+
)
|
|
367
|
+
flat_inputs = {
|
|
368
|
+
"args_0": args[0].numpy(),
|
|
369
|
+
"y": kwargs["y"].numpy(),
|
|
370
|
+
"z_data_1": kwargs["z"].data_1.numpy(),
|
|
371
|
+
"z_data_2_0_0": kwargs["z"].data_2[0][0].numpy(),
|
|
372
|
+
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
373
|
+
}
|
|
374
|
+
self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
375
|
+
|
|
376
|
+
def test_convert_model_with_args_nested_kwargs_3(self):
|
|
377
|
+
"""
|
|
378
|
+
Test converting a simple model with both sample_args and nested sample_kwargs.
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
class SampleModel(torch.nn.Module):
|
|
382
|
+
|
|
383
|
+
def forward(self, x, y, z):
|
|
384
|
+
return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
|
|
385
|
+
|
|
386
|
+
args = (torch.randn(10, 10),)
|
|
387
|
+
kwargs = dict(
|
|
388
|
+
y=torch.randn(10, 10),
|
|
389
|
+
z=TestContainer1(
|
|
390
|
+
data_1=torch.randn(10, 10),
|
|
391
|
+
data_2=(dict(foo=torch.randn(10, 10)), torch.randn(10, 10)),
|
|
392
|
+
),
|
|
393
|
+
)
|
|
394
|
+
flat_inputs = {
|
|
395
|
+
"args_0": args[0].numpy(),
|
|
396
|
+
"y": kwargs["y"].numpy(),
|
|
397
|
+
"z_data_1": kwargs["z"].data_1.numpy(),
|
|
398
|
+
"z_data_2_0_foo": kwargs["z"].data_2[0]["foo"].numpy(),
|
|
399
|
+
"z_data_2_1": kwargs["z"].data_2[1].numpy(),
|
|
400
|
+
}
|
|
401
|
+
self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
|
|
402
|
+
|
|
403
|
+
def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
|
|
404
|
+
model.eval()
|
|
405
|
+
edge_model = ai_edge_torch.convert(model, args, kwargs)
|
|
406
|
+
interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
|
|
407
|
+
runner = interpreter.get_signature_runner("serving_default")
|
|
408
|
+
input_details = runner.get_input_details()
|
|
409
|
+
self.assertEqual(input_details.keys(), flat_inputs.keys())
|
|
410
|
+
|
|
411
|
+
reference_output = model(*args, **kwargs)
|
|
412
|
+
tflite_output = edge_model(**flat_inputs)
|
|
413
|
+
np.testing.assert_almost_equal(reference_output, tflite_output)
|
|
414
|
+
|
|
309
415
|
|
|
310
416
|
if __name__ == "__main__":
|
|
311
417
|
unittest.main()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.
|
|
3
|
+
Version: 0.2.0.dev20240719
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,106
|
|
|
2
2
|
ai_edge_torch/model.py,sha256=8Ba9ia7TCM_fciulw6qObmzdcxL3IaLQKDqpR7Lxp-Q,4440
|
|
3
3
|
ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
4
4
|
ai_edge_torch/convert/conversion.py,sha256=StJHglvx6cii36oi8sj-tZda009e9UqR6ufZOZkP1SY,4137
|
|
5
|
-
ai_edge_torch/convert/conversion_utils.py,sha256=
|
|
5
|
+
ai_edge_torch/convert/conversion_utils.py,sha256=TA-fbRApU_wdZYg8VmQSGiH4sm70iITsLRBBi5vODTw,13813
|
|
6
6
|
ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUSW_nZc,7875
|
|
7
7
|
ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
|
|
8
8
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
|
|
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
|
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
|
|
23
23
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
|
|
24
24
|
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
25
|
-
ai_edge_torch/convert/test/test_convert.py,sha256=
|
|
25
|
+
ai_edge_torch/convert/test/test_convert.py,sha256=itOZDKsh0-0aoly1b1M72M179Yr2BJqtTe6ivueZSc4,12607
|
|
26
26
|
ai_edge_torch/convert/test/test_convert_composites.py,sha256=8UkdPtGkjgSVLCzB_rpM2FmwYuMyt6WE48umX_kr_Sg,7601
|
|
27
27
|
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
|
|
28
28
|
ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
|
|
@@ -114,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
|
|
|
114
114
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
115
115
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
116
116
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
|
|
117
|
-
ai_edge_torch_nightly-0.2.0.
|
|
118
|
-
ai_edge_torch_nightly-0.2.0.
|
|
119
|
-
ai_edge_torch_nightly-0.2.0.
|
|
120
|
-
ai_edge_torch_nightly-0.2.0.
|
|
121
|
-
ai_edge_torch_nightly-0.2.0.
|
|
117
|
+
ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
118
|
+
ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/METADATA,sha256=X9TaI_3Rxn0rk89P3ZcXJlNtEIUBOhOIIKAncN3Xpos,1745
|
|
119
|
+
ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
120
|
+
ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
121
|
+
ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|