ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241004__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/test/test_model_conversion.py +37 -16
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/RECORD +7 -7
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241003.dist-info → ai_edge_torch_nightly-0.3.0.dev20241004.dist-info}/top_level.txt +0 -0
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
|
|
43
43
|
)
|
44
44
|
)
|
45
45
|
|
46
|
-
def
|
46
|
+
def _get_params(self, enable_hlfb: bool):
|
47
|
+
"""Returns a model, edge model and the kwargs to use for testing."""
|
48
|
+
config = toy_model_with_kv_cache.get_model_config()
|
49
|
+
config.enable_hlfb = enable_hlfb
|
50
|
+
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
47
51
|
tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
|
48
52
|
[10], dtype=torch.int
|
49
53
|
)
|
50
54
|
kv = kv_cache.KVCache.from_model_config(config)
|
55
|
+
kwargs = {
|
56
|
+
"tokens": tokens,
|
57
|
+
"input_pos": input_pos,
|
58
|
+
"kv_cache": kv,
|
59
|
+
}
|
51
60
|
|
52
61
|
edge_model = ai_edge_torch.convert(
|
53
62
|
pytorch_model,
|
54
|
-
sample_kwargs=
|
55
|
-
"tokens": tokens,
|
56
|
-
"input_pos": input_pos,
|
57
|
-
"kv_cache": kv,
|
58
|
-
},
|
63
|
+
sample_kwargs=kwargs,
|
59
64
|
)
|
60
65
|
edge_model.set_interpreter_builder(
|
61
66
|
self._interpreter_builder(edge_model.tflite_model())
|
62
67
|
)
|
68
|
+
return pytorch_model, edge_model, kwargs
|
69
|
+
|
70
|
+
def _test_model_with_kv_cache(self, enable_hlfb: bool):
|
71
|
+
pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
|
63
72
|
|
64
73
|
self.assertTrue(
|
65
74
|
test_utils.compare_tflite_torch(
|
66
75
|
edge_model,
|
67
76
|
pytorch_model,
|
68
|
-
tokens,
|
69
|
-
input_pos,
|
70
|
-
|
77
|
+
kwargs["tokens"],
|
78
|
+
kwargs["input_pos"],
|
79
|
+
kwargs["kv_cache"],
|
71
80
|
signature_name="serving_default",
|
72
81
|
atol=1e-5,
|
73
82
|
rtol=1e-5,
|
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
|
|
79
88
|
reason="tests with custom ops are not supported on oss",
|
80
89
|
)
|
81
90
|
def test_toy_model_with_kv_cache(self):
|
82
|
-
|
83
|
-
pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
|
84
|
-
self._test_model_with_kv_cache(config, pytorch_model)
|
91
|
+
self._test_model_with_kv_cache(enable_hlfb=False)
|
85
92
|
|
86
93
|
@googletest.skipIf(
|
87
94
|
ai_edge_config.Config.use_torch_xla,
|
88
95
|
reason="tests with custom ops are not supported on oss",
|
89
96
|
)
|
90
97
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
98
|
+
self._test_model_with_kv_cache(enable_hlfb=True)
|
99
|
+
|
100
|
+
@googletest.skipIf(
|
101
|
+
ai_edge_config.Config.use_torch_xla,
|
102
|
+
reason="tests with custom ops are not supported on oss",
|
103
|
+
)
|
104
|
+
def test_toy_model_has_ekv_op(self):
|
105
|
+
"""Tests that the model has the external kv cache op."""
|
106
|
+
_, edge_model, _ = self._get_params(enable_hlfb=True)
|
107
|
+
interpreter_ = interpreter.InterpreterWithCustomOps(
|
108
|
+
custom_op_registerers=["GenAIOpsRegisterer"],
|
109
|
+
model_content=edge_model.tflite_model(),
|
110
|
+
experimental_default_delegate_latest_features=True,
|
111
|
+
)
|
112
|
+
|
113
|
+
# pylint: disable=protected-access
|
114
|
+
op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
|
115
|
+
self.assertIn("odml.update_external_kv_cache", op_names)
|
95
116
|
|
96
117
|
def _test_multisig_model(self, config, pytorch_model, atol, rtol):
|
97
118
|
# prefill
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241004
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=tIC9MEJewU0lAFO_930WizESB627b7x4xfE3qbYWtLw,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
|
|
122
122
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
123
123
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
124
124
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
125
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256
|
125
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
|
126
126
|
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
|
127
127
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
128
128
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
180
180
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
181
181
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
182
182
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
183
|
-
ai_edge_torch_nightly-0.3.0.
|
184
|
-
ai_edge_torch_nightly-0.3.0.
|
185
|
-
ai_edge_torch_nightly-0.3.0.
|
186
|
-
ai_edge_torch_nightly-0.3.0.
|
187
|
-
ai_edge_torch_nightly-0.3.0.
|
183
|
+
ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
184
|
+
ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/METADATA,sha256=LZEnnjuiIFRFASjn-R5mEPu8juBMx7ZvLgbGZuv9CQw,1897
|
185
|
+
ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
186
|
+
ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
187
|
+
ai_edge_torch_nightly-0.3.0.dev20241004.dist-info/RECORD,,
|
File without changes
|
File without changes
|