ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241004__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- ai_edge_torch/generative/test/test_model_conversion.py +37 -16
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/test/test_model_conversion.py
CHANGED
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  def _test_model_with_kv_cache(self, config, pytorch_model):
+  def _get_params(self, enable_hlfb: bool):
+    """Returns a model, edge model and the kwargs to use for testing."""
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = enable_hlfb
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
+    kwargs = {
+        "tokens": tokens,
+        "input_pos": input_pos,
+        "kv_cache": kv,
+    }
 
     edge_model = ai_edge_torch.convert(
         pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
+        sample_kwargs=kwargs,
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
+    return pytorch_model, edge_model, kwargs
+
+  def _test_model_with_kv_cache(self, enable_hlfb: bool):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-            tokens,
-            input_pos,
-            kv,
+            kwargs["tokens"],
+            kwargs["input_pos"],
+            kwargs["kv_cache"],
             signature_name="serving_default",
             atol=1e-5,
             rtol=1e-5,
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=True)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_toy_model_has_ekv_op(self):
+    """Tests that the model has the external kv cache op."""
+    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    interpreter_ = interpreter.InterpreterWithCustomOps(
+        custom_op_registerers=["GenAIOpsRegisterer"],
+        model_content=edge_model.tflite_model(),
+        experimental_default_delegate_latest_features=True,
+    )
+
+    # pylint: disable=protected-access
+    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    self.assertIn("odml.update_external_kv_cache", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
ai_edge_torch/version.py
CHANGED
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241003
+Version: 0.3.0.dev20241004
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=tIC9MEJewU0lAFO_930WizESB627b7x4xfE3qbYWtLw,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/METADATA,sha256=LZEnnjuiIFRFASjn-R5mEPu8juBMx7ZvLgbGZuv9CQw,1897
+ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/RECORD,,
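
Each RECORD row above has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe base64-encoded SHA-256 of the file with trailing `=` padding stripped (the standard wheel RECORD format). The sketch below recomputes one entry from an unpacked wheel; the file path is illustrative only.

```python
# Recompute a RECORD-style entry for a file extracted from the wheel.
import base64
import hashlib


def record_entry(path: str) -> str:
    with open(path, "rb") as f:
        data = f.read()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return f"{path},sha256={digest},{len(data)}"


# For the dev20241004 wheel, this should reproduce the RECORD line above:
# ai_edge_torch/version.py,sha256=tIC9MEJewU0lAFO_930WizESB627b7x4xfE3qbYWtLw,706
print(record_entry("ai_edge_torch/version.py"))
```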
{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/LICENSE
File without changes

{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/WHEEL
File without changes

{ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/top_level.txt
File without changes