ai-edge-torch-nightly 0.3.0.dev20241003__py3-none-any.whl → 0.3.0.dev20241005__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -75,9 +75,7 @@ class CLIP(nn.Module):
75
75
  )
76
76
 
77
77
  @torch.inference_mode
78
- def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor:
79
- tokens = tokens.type(torch.int)
80
-
78
+ def forward(self, tokens: torch.IntTensor) -> torch.FloatTensor:
81
79
  state = self.tok_embedding(tokens) + self.tok_embedding_position
82
80
  for layer in self.transformer_blocks:
83
81
  state = layer(state, mask=self.mask_cache)
@@ -13,47 +13,54 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- import argparse
17
16
  import os
18
- from pathlib import Path
19
- from typing import Optional
17
+ import pathlib
20
18
 
19
+ from absl import app
20
+ from absl import flags
21
21
  import ai_edge_torch
22
- import ai_edge_torch.generative.examples.stable_diffusion.clip as clip
23
- import ai_edge_torch.generative.examples.stable_diffusion.decoder as decoder
24
- import ai_edge_torch.generative.examples.stable_diffusion.diffusion as diffusion
25
- from ai_edge_torch.generative.examples.stable_diffusion.encoder import Encoder
26
- import ai_edge_torch.generative.examples.stable_diffusion.util as util
22
+ from ai_edge_torch.generative.examples.stable_diffusion import clip
23
+ from ai_edge_torch.generative.examples.stable_diffusion import decoder
24
+ from ai_edge_torch.generative.examples.stable_diffusion import diffusion
25
+ from ai_edge_torch.generative.examples.stable_diffusion import util
27
26
  from ai_edge_torch.generative.quantize import quant_recipes
28
- import ai_edge_torch.generative.utilities.stable_diffusion_loader as stable_diffusion_loader
27
+ from ai_edge_torch.generative.utilities import stable_diffusion_loader
29
28
  import torch
30
29
 
31
- arg_parser = argparse.ArgumentParser()
32
- arg_parser.add_argument(
33
- '--clip_ckpt',
34
- type=str,
30
+ _CLIP_CKPT = flags.DEFINE_string(
31
+ 'clip_ckpt',
32
+ None,
35
33
  help='Path to source CLIP model checkpoint',
36
34
  required=True,
37
35
  )
38
- arg_parser.add_argument(
39
- '--diffusion_ckpt',
40
- type=str,
36
+
37
+ _DIFFUSION_CKPT = flags.DEFINE_string(
38
+ 'diffusion_ckpt',
39
+ None,
41
40
  help='Path to source diffusion model checkpoint',
42
41
  required=True,
43
42
  )
44
- arg_parser.add_argument(
45
- '--decoder_ckpt',
46
- type=str,
43
+
44
+ _DECODER_CKPT = flags.DEFINE_string(
45
+ 'decoder_ckpt',
46
+ None,
47
47
  help='Path to source image decoder model checkpoint',
48
48
  required=True,
49
49
  )
50
- arg_parser.add_argument(
51
- '--output_dir',
52
- type=str,
50
+
51
+ _OUTPUT_DIR = flags.DEFINE_string(
52
+ 'output_dir',
53
+ None,
53
54
  help='Path to the converted TF Lite directory.',
54
55
  required=True,
55
56
  )
56
57
 
58
+ _QUANTIZE = flags.DEFINE_bool(
59
+ 'quantize',
60
+ help='Whether to quantize the model during conversion.',
61
+ default=True,
62
+ )
63
+
57
64
 
58
65
  @torch.inference_mode
59
66
  def convert_stable_diffusion_to_tflite(
@@ -111,7 +118,7 @@ def convert_stable_diffusion_to_tflite(
111
118
  time_embedding = util.get_time_embedding(timestamp)
112
119
 
113
120
  if not os.path.exists(output_dir):
114
- Path(output_dir).mkdir(parents=True, exist_ok=True)
121
+ pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)
115
122
 
116
123
  quant_config = (
117
124
  quant_recipes.full_int8_weight_only_recipe() if quantize else None
@@ -142,14 +149,15 @@ def convert_stable_diffusion_to_tflite(
142
149
  ).export(f'{output_dir}/decoder.tflite')
143
150
 
144
151
 
145
- if __name__ == '__main__':
146
- args = arg_parser.parse_args()
152
+ def main(_):
147
153
  convert_stable_diffusion_to_tflite(
148
- output_dir=args.output_dir,
149
- clip_ckpt_path=args.clip_ckpt,
150
- diffusion_ckpt_path=args.diffusion_ckpt,
151
- decoder_ckpt_path=args.decoder_ckpt,
152
- image_height=512,
153
- image_width=512,
154
- quantize=True,
154
+ output_dir=_OUTPUT_DIR.value,
155
+ clip_ckpt_path=_CLIP_CKPT.value,
156
+ diffusion_ckpt_path=_DIFFUSION_CKPT.value,
157
+ decoder_ckpt_path=_DECODER_CKPT.value,
158
+ quantize=_QUANTIZE.value,
155
159
  )
160
+
161
+
162
+ if __name__ == '__main__':
163
+ app.run(main)
@@ -43,31 +43,40 @@ class TestModelConversion(googletest.TestCase):
43
43
  )
44
44
  )
45
45
 
46
- def _test_model_with_kv_cache(self, config, pytorch_model):
46
+ def _get_params(self, enable_hlfb: bool):
47
+ """Returns a model, edge model and the kwargs to use for testing."""
48
+ config = toy_model_with_kv_cache.get_model_config()
49
+ config.enable_hlfb = enable_hlfb
50
+ pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
47
51
  tokens, input_pos = torch.tensor([[1]], dtype=torch.int), torch.tensor(
48
52
  [10], dtype=torch.int
49
53
  )
50
54
  kv = kv_cache.KVCache.from_model_config(config)
55
+ kwargs = {
56
+ "tokens": tokens,
57
+ "input_pos": input_pos,
58
+ "kv_cache": kv,
59
+ }
51
60
 
52
61
  edge_model = ai_edge_torch.convert(
53
62
  pytorch_model,
54
- sample_kwargs={
55
- "tokens": tokens,
56
- "input_pos": input_pos,
57
- "kv_cache": kv,
58
- },
63
+ sample_kwargs=kwargs,
59
64
  )
60
65
  edge_model.set_interpreter_builder(
61
66
  self._interpreter_builder(edge_model.tflite_model())
62
67
  )
68
+ return pytorch_model, edge_model, kwargs
69
+
70
+ def _test_model_with_kv_cache(self, enable_hlfb: bool):
71
+ pytorch_model, edge_model, kwargs = self._get_params(enable_hlfb)
63
72
 
64
73
  self.assertTrue(
65
74
  test_utils.compare_tflite_torch(
66
75
  edge_model,
67
76
  pytorch_model,
68
- tokens,
69
- input_pos,
70
- kv,
77
+ kwargs["tokens"],
78
+ kwargs["input_pos"],
79
+ kwargs["kv_cache"],
71
80
  signature_name="serving_default",
72
81
  atol=1e-5,
73
82
  rtol=1e-5,
@@ -79,19 +88,31 @@ class TestModelConversion(googletest.TestCase):
79
88
  reason="tests with custom ops are not supported on oss",
80
89
  )
81
90
  def test_toy_model_with_kv_cache(self):
82
- config = toy_model_with_kv_cache.get_model_config()
83
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
84
- self._test_model_with_kv_cache(config, pytorch_model)
91
+ self._test_model_with_kv_cache(enable_hlfb=False)
85
92
 
86
93
  @googletest.skipIf(
87
94
  ai_edge_config.Config.use_torch_xla,
88
95
  reason="tests with custom ops are not supported on oss",
89
96
  )
90
97
  def test_toy_model_with_kv_cache_with_hlfb(self):
91
- config = toy_model_with_kv_cache.get_model_config()
92
- config.enable_hlfb = True
93
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKVCache(config).eval()
94
- self._test_model_with_kv_cache(config, pytorch_model)
98
+ self._test_model_with_kv_cache(enable_hlfb=True)
99
+
100
+ @googletest.skipIf(
101
+ ai_edge_config.Config.use_torch_xla,
102
+ reason="tests with custom ops are not supported on oss",
103
+ )
104
+ def test_toy_model_has_ekv_op(self):
105
+ """Tests that the model has the external kv cache op."""
106
+ _, edge_model, _ = self._get_params(enable_hlfb=True)
107
+ interpreter_ = interpreter.InterpreterWithCustomOps(
108
+ custom_op_registerers=["GenAIOpsRegisterer"],
109
+ model_content=edge_model.tflite_model(),
110
+ experimental_default_delegate_latest_features=True,
111
+ )
112
+
113
+ # pylint: disable=protected-access
114
+ op_names = [op["op_name"] for op in interpreter_._get_ops_details()]
115
+ self.assertIn("odml.update_external_kv_cache", op_names)
95
116
 
96
117
  def _test_multisig_model(self, config, pytorch_model, atol, rtol):
97
118
  # prefill
@@ -156,8 +156,8 @@ def translate_to_ai_edge_recipe(
156
156
 
157
157
 
158
158
  def quantize_model(
159
- model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
159
+ model: bytes, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
160
160
  ) -> bytearray:
161
- qt = quantizer.Quantizer(bytearray(model), recipe)
161
+ qt = quantizer.Quantizer(model, recipe)
162
162
  result = qt.quantize()
163
163
  return result.quantized_model
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241003"
16
+ __version__ = "0.3.0.dev20241005"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241003
3
+ Version: 0.3.0.dev20241005
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=WKaZCocAyLb42oFdC07BQ6qpSfohXBwt-HKGV7S2fXw,706
6
+ ai_edge_torch/version.py,sha256=y5TOP0Z8qFsjIuJuJtSmzOUpHyTa9UH46RdJjtRWYQA,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -71,8 +71,8 @@ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZk
71
71
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
72
72
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
73
73
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
74
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lwWrKY1NpnbvHQRenpltVN65QlzjWmSScl5CLSipBkc,6110
75
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
74
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
75
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=i9mcBITt4jJqKLA4Qdt3uFotCrglv14tPg8VnqsVnaI,5004
76
76
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
77
77
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
78
78
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
@@ -122,7 +122,7 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
122
122
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
123
123
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
124
124
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
125
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=-qB-JEIfPFNlpGyJA1TYo_5fawTdyf1C6ee8cP4kYOY,5530
125
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=a4TzSw8KMxEafirxqkykZi-WgTs5Z7wHp-J1AfjRDzA,6353
126
126
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=bVCm_mubuGszCBON6oRjQXcBgPZqlVmmOaLWwhZJLio,9060
127
127
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
128
128
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -147,7 +147,7 @@ ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjF
147
147
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
148
148
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
149
149
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=S7RWzauts-15xP6VYuM3aAd9cyAGHstYD2A4dlv3d30,9059
150
- ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
150
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
151
151
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
152
152
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
153
153
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
@@ -180,8 +180,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
180
180
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
181
181
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
182
182
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
183
- ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
- ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/METADATA,sha256=a6Q1LozCx-4NWkm1EKZJFeCJTYiTNUSigoVwRevV0oc,1897
185
- ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
- ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
- ai_edge_torch_nightly-0.3.0.dev20241003.dist-info/RECORD,,
183
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
184
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/METADATA,sha256=O3P5ofz2aERMO1xbvIC7Z4RWsUNLJOZgn4pxEH3ftRc,1897
185
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
186
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
187
+ ai_edge_torch_nightly-0.3.0.dev20241005.dist-info/RECORD,,