ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/generative/examples/stable_diffusion/clip.py CHANGED
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
     )
 
   @torch.inference_mode
-  def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-    tokens = tokens.type(torch.int)
-
+  def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
     state = self.tok_embedding(tokens) + self.tok_embedding_position
     for layer in self.transformer_blocks:
       state = layer(state, mask=self.mask_cache)
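
With the internal cast removed from forward, callers are now expected to supply int32 token ids themselves. A minimal sketch of the new calling convention, assuming a CLIP instance built elsewhere (the token ids below are illustrative, not taken from this package):

import torch

# Illustrative token ids; a real caller would produce these with the CLIP tokenizer.
tokens = torch.tensor([[49406, 320, 1125, 539, 320, 2368, 49407]], dtype=torch.int)  # int32

# clip_model = clip.CLIP(config)   # hypothetical construction; see the example's loader utilities
# embeddings = clip_model(tokens)  # forward() no longer converts the dtype for the caller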

ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py CHANGED
@@ -13,47 +13,54 @@
 # limitations under the License.
 # ==============================================================================
 
-import argparse
 import os
-from pathlib import Path
-from typing import Optional
+import pathlib
 
+from absl import app
+from absl import flags
 import ai_edge_torch
-import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
-import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
-import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
-from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
+from ai_edge_torch.generative.examples.stable_diffusion import clip
+from ai_edge_torch.generative.examples.stable_diffusion import decoder
+from ai_edge_torch.generative.examples.stable_diffusion import diffusion
+from ai_edge_torch.generative.examples.stable_diffusion import util
 from ai_edge_torch.generative.quantize import quant_recipes
-import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
+from ai_edge_torch.generative.utilities import stable_diffusion_loader
 import torch
 
-arg_parser = argparse.ArgumentParser()
-arg_parser.add_argument(
-    '--clip_ckpt',
-    type=str,
+_CLIP_CKPT = flags.DEFINE_string(
+    'clip_ckpt',
+    None,
     help='Path to source CLIP model checkpoint',
     required=True,
 )
-arg_parser.add_argument(
-    '--diffusion_ckpt',
-    type=str,
+
+_DIFFUSION_CKPT = flags.DEFINE_string(
+    'diffusion_ckpt',
+    None,
     help='Path to source diffusion model checkpoint',
     required=True,
 )
-arg_parser.add_argument(
-    '--decoder_ckpt',
-    type=str,
+
+_DECODER_CKPT = flags.DEFINE_string(
+    'decoder_ckpt',
+    None,
     help='Path to source image decoder model checkpoint',
     required=True,
 )
-arg_parser.add_argument(
-    '--output_dir',
-    type=str,
+
+_OUTPUT_DIR = flags.DEFINE_string(
+    'output_dir',
+    None,
     help='Path to the converted TF Lite directory.',
     required=True,
 )
 
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    help='Whether to quantize the model during conversion.',
+    default=True,
+)
+
 
 @torch.inference_mode
 def convert_stable_diffusion_to_tflite(
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
   time_embedding = util.get_time_embedding(timestamp)
 
   if not os.path.exists(output_dir):
-    Path(output_dir).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
 
   quant_config = (
       quant_recipes.full_int8_weight_only_recipe() if quantize else None
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
   ).export(f'{output_dir}/decoder.tflite')
 
 
-if __name__ == '__main__':
-  args = arg_parser.parse_args()
+def main(_):
   convert_stable_diffusion_to_tflite(
-      output_dir=args.output_dir,
-      clip_ckpt_path=args.clip_ckpt,
-      diffusion_ckpt_path=args.diffusion_ckpt,
-      decoder_ckpt_path=args.decoder_ckpt,
-      image_height=512,
-      image_width=512,
-      quantize=True,
+      output_dir=_OUTPUT_DIR.value,
+      clip_ckpt_path=_CLIP_CKPT.value,
+      diffusion_ckpt_path=_DIFFUSION_CKPT.value,
+      decoder_ckpt_path=_DECODER_CKPT.value,
+      quantize=_QUANTIZE.value,
   )
+
+
+if __name__ == '__main__':
+  app.run(main)
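
Since the script now defines absl flags and runs through app.run(main), required flags are validated at parse time and the boolean quantize flag gains the standard --quantize/--noquantize forms. The conversion entry point can also be called directly; a sketch under the assumption that image_height and image_width keep sensible defaults in the function signature (all paths are placeholders):

from ai_edge_torch.generative.examples.stable_diffusion import convert_to_tflite

# Hypothetical CLI equivalent:
#   python convert_to_tflite.py --clip_ckpt=/ckpts/clip.pt --diffusion_ckpt=/ckpts/diffusion.pt \
#     --decoder_ckpt=/ckpts/decoder.pt --output_dir=/tmp/sd_tflite --noquantize
convert_to_tflite.convert_stable_diffusion_to_tflite(
    output_dir='/tmp/sd_tflite',
    clip_ckpt_path='/ckpts/clip.pt',
    diffusion_ckpt_path='/ckpts/diffusion.pt',
    decoder_ckpt_path='/ckpts/decoder.pt',
    quantize=True,
)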

ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
         )
     )
 
-  def _test_model_with_kv_cache(self, config, pytorch_model):
+  def _get_params(self, enable_hlfb: bool):
+    """Returns a model, edge model and the kwargs to use for testing."""
+    config = toy_model_with_kv_cache.get_model_config()
+    config.enable_hlfb = enable_hlfb
+    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
     tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
         [10], dtype=torch.int
     )
     kv = kv_cache.KVCache.from_model_config(config)
+    kwargs = {
+        "tokens": tokens,
+        "input_pos": input_pos,
+        "kv_cache": kv,
+    }
 
     edge_model = ai_edge_torch.convert(
         pytorch_model,
-        sample_kwargs={
-            "tokens": tokens,
-            "input_pos": input_pos,
-            "kv_cache": kv,
-        },
+        sample_kwargs=kwargs,
     )
     edge_model.set_interpreter_builder(
         self._interpreter_builder(edge_model.tflite_model())
     )
+    return pytorch_model, edge_model, kwargs
+
+  def _test_model_with_kv_cache(self, enable_hlfb: bool):
+    pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
 
     self.assertTrue(
         test_utils.compare_tflite_torch(
             edge_model,
             pytorch_model,
-            tokens,
-            input_pos,
-            kv,
+            kwargs["tokens"],
+            kwargs["input_pos"],
+            kwargs["kv_cache"],
            signature_name="serving_default",
            atol=1e-5,
            rtol=1e-5,
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=False)
 
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
   )
   def test_toy_model_with_kv_cache_with_hlfb(self):
-    config = toy_model_with_kv_cache.get_model_config()
-    config.enable_hlfb = True
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
-    self._test_model_with_kv_cache(config, pytorch_model)
+    self._test_model_with_kv_cache(enable_hlfb=True)
+
+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_toy_model_has_ekv_op(self):
+    """Tests that the model has the external kv cache op."""
+    _, edge_model, _ = self._get_params(enable_hlfb=True)
+    interpreter_ = interpreter.InterpreterWithCustomOps(
+        custom_op_registerers=["GenAIOpsRegisterer"],
+        model_content=edge_model.tflite_model(),
+        experimental_default_delegate_latest_features=True,
+    )
+
+    # pylint: disable=protected-access
+    op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
+    self.assertIn("odml.update_external_kv_cache", op_names)
 
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill

ai_edge_torch/lowertools/translate_recipe.py CHANGED
@@ -156,8 +156,8 @@ def translate_to_ai_edge_recipe(
 
 
 def quantize_model(
-    model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
+    model: bytes, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
 ) -> bytearray:
-  qt = quantizer.Quantizer(bytearray(model), recipe)
+  qt = quantizer.Quantizer(model, recipe)
   result = qt.quantize()
   return result.quantized_model
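
quantize_model now accepts the serialized TFLite flatbuffer as plain bytes, so a model read from disk no longer needs an intermediate bytearray copy. A minimal sketch, assuming a recipe produced by translate_to_ai_edge_recipe and a placeholder model path:

import pathlib

from ai_edge_torch.lowertools import translate_recipe

tflite_bytes = pathlib.Path('/tmp/model.tflite').read_bytes()  # plain bytes, no bytearray() wrapper
# recipe = translate_recipe.translate_to_ai_edge_recipe(...)   # built from a generative quant recipe
# quantized_model = translate_recipe.quantize_model(tflite_bytes, recipe)
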
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20241003"
+__version__ = "0.3.0.dev20241005"

ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241003
+Version: 0.3.0.dev20241005
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD → ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD CHANGED
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=WKaZCocAyLb42oFdC07BQ6qpSfohXBwt-HKGV7S2fXw,706
+ai_edge_torch/version.py,sha256=y5TOP0Z8qFsjIuJuJtSmzOUpHyTa9UH46RdJjtRWYQA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -71,8 +71,8 @@ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZk
 ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
-ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lwWrKY1NpnbvHQRenpltVN65QlzjWmSScl5CLSipBkc,6110
-ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
+ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
+ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=i9mcBITt4jJqKLA4Qdt3uFotCrglv14tPg8VnqsVnaI,5004
 ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
 ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
 ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=-qB-JEIfPFNlpGyJA1TYo_5fawTdyf1C6ee8cP4kYOY,5530
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjF
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
 ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
 ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
-ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
+ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA,sha256=a6Q1LozCx-4NWkm1EKZJFeCJTYiTNUSigoVwRevV0oc,1897
-ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA,sha256=O3P5ofz2aERMO1xbvIC7Z4RWsUNLJOZgn4pxEH3ftRc,1897
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD,,