ai-edge-torch-nightly 0.2.0.dev20240718__py3-none-any.whl → 0.2.0.dev20240719__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. The original page linked to further details about this warning; that link is not included in this copy.

@@ -20,7 +20,7 @@ import gc
20
20
  import itertools
21
21
  import logging
22
22
  import tempfile
23
- from typing import Any, Dict, Optional, Tuple, Union
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
24
 
25
25
  import torch
26
26
  import torch.utils._pytree as pytree
@@ -79,28 +79,49 @@ class Signature:
79
79
  for i in range(args_spec.num_leaves):
80
80
  names.append(f"args_{i}")
81
81
 
82
- dict_context = (
83
- kwargs_spec.context
84
- if kwargs_spec.type is not collections.defaultdict
85
- # ignore mismatch of `default_factory` for defaultdict
86
- else kwargs_spec.context[1]
82
+ kwargs_names = self._flat_kwarg_names(
83
+ kwargs_spec.children_specs, kwargs_spec.context
87
84
  )
85
+ names.extend(kwargs_names)
86
+ return names
88
87
 
89
- for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
90
- if value_spec.num_leaves == 1:
91
- names.append(name)
88
+ def _flat_kwarg_names(self, specs, context) -> List[str]:
89
+ flat_names = []
90
+ if context is None:
91
+ for i, spec in enumerate(specs):
92
+ if spec.children_specs:
93
+ flat_names.extend(
94
+ [
95
+ f"{i}_{name}"
96
+ for name in self._flat_kwarg_names(spec.children_specs, spec.context)
97
+ ]
98
+ )
99
+ else:
100
+ flat_names.append(f"{i}")
101
+ else:
102
+ flat_ctx = self._flatten_list(context)
103
+ for prefix, spec in zip(flat_ctx, specs):
104
+ leaf_flat_names = self._flat_kwarg_names(spec.children_specs, spec.context)
105
+ if leaf_flat_names:
106
+ flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
107
+ else:
108
+ flat_names.append(prefix)
109
+
110
+ return flat_names
111
+
112
+ def _flatten_list(self, l: List) -> List:
113
+ flattened = []
114
+ for item in l:
115
+ if isinstance(item, list):
116
+ flattened.extend(self._flatten_list(item))
92
117
  else:
93
- # value_spec.num_leaves may be greater than 1 when the value is a (nested)
94
- # tuple of tensors. We haven't decided how we should support flattenable
95
- # tensor containers as inputs.
96
- # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
97
- for i in range(value_spec.num_leaves):
98
- names.append(f"{name}_{i}")
99
- return names
118
+ flattened.append(item)
119
+ return flattened
100
120
 
101
121
  @property
102
- def flat_args(self) -> tuple[torch.Tensor]:
103
- return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
122
+ def flat_args(self) -> tuple[Any]:
123
+ args, kwargs = self._normalized_sample_args_kwargs
124
+ return tuple([*args, *kwargs.values()])
104
125
 
105
126
 
106
127
  def exported_program_to_stablehlo_bundle(
@@ -14,10 +14,14 @@
14
14
  # ==============================================================================
15
15
 
16
16
 
17
+ from dataclasses import dataclass
17
18
  import os
18
19
  import tempfile
20
+ from typing import Tuple
19
21
  import unittest
20
22
 
23
+ import numpy as np
24
+ import tensorflow as tf
21
25
  import torch
22
26
  import torchvision
23
27
 
@@ -26,6 +30,15 @@ from ai_edge_torch.convert import conversion_utils as cutils
26
30
  from ai_edge_torch.testing import model_coverage
27
31
 
28
32
 
33
+ @dataclass
34
+ class TestContainer1:
35
+ data_1: torch.Tensor
36
+ data_2: Tuple[torch.Tensor, torch.Tensor]
37
+
38
+
39
+ torch.export.register_dataclass(TestContainer1, serialized_type_name="TestContainer1")
40
+
41
+
29
42
  class TestConvert(unittest.TestCase):
30
43
  """Tests conversion of various modules."""
31
44
 
@@ -306,6 +319,99 @@ class TestConvert(unittest.TestCase):
306
319
  model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
307
320
  )
308
321
 
322
+ def test_convert_model_with_args_nested_kwargs_1(self):
323
+ """
324
+ Test converting a simple model with both sample_args and nested sample_kwargs.
325
+ """
326
+
327
+ class SampleModel(torch.nn.Module):
328
+
329
+ def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
330
+ return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
331
+
332
+ args = (torch.randn(10, 10),)
333
+ kwargs = dict(
334
+ y=torch.randn(10, 10),
335
+ z=TestContainer1(
336
+ data_1=torch.randn(10, 10),
337
+ data_2=(torch.randn(10, 10), torch.randn(10, 10)),
338
+ ),
339
+ )
340
+ flat_inputs = {
341
+ "args_0": args[0].numpy(),
342
+ "y": kwargs["y"].numpy(),
343
+ "z_data_1": kwargs["z"].data_1.numpy(),
344
+ "z_data_2_0": kwargs["z"].data_2[0].numpy(),
345
+ "z_data_2_1": kwargs["z"].data_2[1].numpy(),
346
+ }
347
+ self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
348
+
349
+ def test_convert_model_with_args_nested_kwargs_2(self):
350
+ """
351
+ Test converting a simple model with both sample_args and nested sample_kwargs.
352
+ """
353
+
354
+ class SampleModel(torch.nn.Module):
355
+
356
+ def forward(self, x, y, z):
357
+ return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
358
+
359
+ args = (torch.randn(10, 10),)
360
+ kwargs = dict(
361
+ y=torch.randn(10, 10),
362
+ z=TestContainer1(
363
+ data_1=torch.randn(10, 10),
364
+ data_2=[(torch.randn(10, 10),), torch.randn(10, 10)],
365
+ ),
366
+ )
367
+ flat_inputs = {
368
+ "args_0": args[0].numpy(),
369
+ "y": kwargs["y"].numpy(),
370
+ "z_data_1": kwargs["z"].data_1.numpy(),
371
+ "z_data_2_0_0": kwargs["z"].data_2[0][0].numpy(),
372
+ "z_data_2_1": kwargs["z"].data_2[1].numpy(),
373
+ }
374
+ self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
375
+
376
+ def test_convert_model_with_args_nested_kwargs_3(self):
377
+ """
378
+ Test converting a simple model with both sample_args and nested sample_kwargs.
379
+ """
380
+
381
+ class SampleModel(torch.nn.Module):
382
+
383
+ def forward(self, x, y, z):
384
+ return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
385
+
386
+ args = (torch.randn(10, 10),)
387
+ kwargs = dict(
388
+ y=torch.randn(10, 10),
389
+ z=TestContainer1(
390
+ data_1=torch.randn(10, 10),
391
+ data_2=(dict(foo=torch.randn(10, 10)), torch.randn(10, 10)),
392
+ ),
393
+ )
394
+ flat_inputs = {
395
+ "args_0": args[0].numpy(),
396
+ "y": kwargs["y"].numpy(),
397
+ "z_data_1": kwargs["z"].data_1.numpy(),
398
+ "z_data_2_0_foo": kwargs["z"].data_2[0]["foo"].numpy(),
399
+ "z_data_2_1": kwargs["z"].data_2[1].numpy(),
400
+ }
401
+ self._compare_tflite_torch_args_kwargs(SampleModel(), args, kwargs, flat_inputs)
402
+
403
+ def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
404
+ model.eval()
405
+ edge_model = ai_edge_torch.convert(model, args, kwargs)
406
+ interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
407
+ runner = interpreter.get_signature_runner("serving_default")
408
+ input_details = runner.get_input_details()
409
+ self.assertEqual(input_details.keys(), flat_inputs.keys())
410
+
411
+ reference_output = model(*args, **kwargs)
412
+ tflite_output = edge_model(**flat_inputs)
413
+ np.testing.assert_almost_equal(reference_output, tflite_output)
414
+
309
415
 
310
416
  if __name__ == "__main__":
311
417
  unittest.main()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240718
3
+ Version: 0.2.0.dev20240719
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,106
2
2
  ai_edge_torch/model.py,sha256=8Ba9ia7TCM_fciulw6qObmzdcxL3IaLQKDqpR7Lxp-Q,4440
3
3
  ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
4
4
  ai_edge_torch/convert/conversion.py,sha256=StJHglvx6cii36oi8sj-tZda009e9UqR6ufZOZkP1SY,4137
5
- ai_edge_torch/convert/conversion_utils.py,sha256=PKXIlSCU-8DhppNBh9ICDNUlEOpV0HgCbt85jDVe3rA,13394
5
+ ai_edge_torch/convert/conversion_utils.py,sha256=TA-fbRApU_wdZYg8VmQSGiH4sm70iITsLRBBi5vODTw,13813
6
6
  ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUSW_nZc,7875
7
7
  ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
8
8
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
23
23
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
24
24
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
25
- ai_edge_torch/convert/test/test_convert.py,sha256=h0vOffr8saDQRkiXljNWDZ17EBjnS4xAtxd8DxETleY,9081
25
+ ai_edge_torch/convert/test/test_convert.py,sha256=itOZDKsh0-0aoly1b1M72M179Yr2BJqtTe6ivueZSc4,12607
26
26
  ai_edge_torch/convert/test/test_convert_composites.py,sha256=8UkdPtGkjgSVLCzB_rpM2FmwYuMyt6WE48umX_kr_Sg,7601
27
27
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
28
28
  ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
@@ -114,8 +114,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDd
114
114
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
115
115
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
116
116
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
117
- ai_edge_torch_nightly-0.2.0.dev20240718.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
- ai_edge_torch_nightly-0.2.0.dev20240718.dist-info/METADATA,sha256=r8YZWPZEhL5gi1oIR9sDZppTSZIuxeHH5isLO4NiSj8,1745
119
- ai_edge_torch_nightly-0.2.0.dev20240718.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
- ai_edge_torch_nightly-0.2.0.dev20240718.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
- ai_edge_torch_nightly-0.2.0.dev20240718.dist-info/RECORD,,
117
+ ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
+ ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/METADATA,sha256=X9TaI_3Rxn0rk89P3ZcXJlNtEIUBOhOIIKAncN3Xpos,1745
119
+ ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
+ ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
+ ai_edge_torch_nightly-0.2.0.dev20240719.dist-info/RECORD,,