ai-edge-torch-nightly 0.3.0.dev20241027__py3-none-any.whl → 0.3.0.dev20241029__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -15,7 +15,9 @@
15
15
 
16
16
  import ai_edge_torch
17
17
  from ai_edge_torch.generative.examples.gemma import gemma1
18
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
18
19
  from ai_edge_torch.generative.quantize import quant_recipes
20
+ from ai_edge_torch.generative.utilities import model_builder
19
21
  import numpy as np
20
22
  import torch
21
23
 
@@ -23,11 +25,12 @@ import torch
23
25
  def main():
24
26
  # Build a PyTorch model as usual
25
27
  config = gemma1.get_fake_model_config()
26
- model = gemma1.Gemma(config)
28
+ model = model_builder.DecoderOnlyModel(config).eval()
27
29
  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
28
30
  tokens = torch.full((1, 10), 0, dtype=torch.int, device="cpu")
29
31
  tokens[0, :4] = idx
30
32
  input_pos = torch.arange(0, 10, dtype=torch.int)
33
+ kv = kv_utils.KVCache.from_model_config(config)
31
34
 
32
35
  # Create a quantization recipe to be applied to the model
33
36
  quant_config = quant_recipes.full_int8_dynamic_recipe()
@@ -35,7 +38,7 @@ def main():
35
38
 
36
39
  # Convert with quantization
37
40
  edge_model = ai_edge_torch.convert(
38
- model, (tokens, input_pos), quant_config=quant_config
41
+ model, (tokens, input_pos, kv), quant_config=quant_config
39
42
  )
40
43
  edge_model.export("/tmp/gemma_2b_quantized.tflite")
41
44
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241027"
16
+ __version__ = "0.3.0.dev20241029"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241027
3
+ Version: 0.3.0.dev20241029
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=SrYveglaiA_DXPoRBqSXClWM1q7853I5ujRorq_MV0M,4251
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=VekzumwXByceYkTQ97jSNSKfX2vYBmx4ZSsHs9cyT-0,706
6
+ ai_edge_torch/version.py,sha256=BBJF2KL772nA3u0liHz3Awc8txMvaam40qeMeEdgqqo,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -119,7 +119,7 @@ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=JwndhL3Z31TvkdGlAoTL5PQ
119
119
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
120
120
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
121
121
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
122
- ai_edge_torch/generative/quantize/example.py,sha256=tlACaRsz6lqOxakzpXVFJZYfFKOiFqetcYVJqWVRdPE,1542
122
+ ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
123
123
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
124
124
  ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1UHAwdbChkgPShiVaz4CE,5156
125
125
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
@@ -186,8 +186,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
186
186
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
187
187
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
188
188
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
189
- ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
190
- ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/METADATA,sha256=WYTOBwCoMZ3Z8G223xG54Lj8PTR9HUW2Yr5dUVtF0nA,1897
191
- ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
192
- ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
193
- ai_edge_torch_nightly-0.3.0.dev20241027.dist-info/RECORD,,
189
+ ai_edge_torch_nightly-0.3.0.dev20241029.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
190
+ ai_edge_torch_nightly-0.3.0.dev20241029.dist-info/METADATA,sha256=W7mORj6kIG6zf-dO9VElbtwjOl5RaxGz1W365OELbjY,1897
191
+ ai_edge_torch_nightly-0.3.0.dev20241029.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
192
+ ai_edge_torch_nightly-0.3.0.dev20241029.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
193
+ ai_edge_torch_nightly-0.3.0.dev20241029.dist-info/RECORD,,