ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42):
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
  10. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  11. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
  12. ai_edge_torch/generative/layers/attention.py +60 -63
  13. ai_edge_torch/generative/layers/builder.py +4 -2
  14. ai_edge_torch/generative/layers/kv_cache.py +160 -51
  15. ai_edge_torch/generative/layers/model_config.py +1 -0
  16. ai_edge_torch/generative/layers/normalization.py +158 -0
  17. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  18. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
  19. ai_edge_torch/generative/test/test_loader.py +1 -1
  20. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  21. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  22. ai_edge_torch/generative/test/utils.py +54 -0
  23. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  24. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  25. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  26. ai_edge_torch/version.py +1 -1
  27. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
  28. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
  29. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  30. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  31. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  32. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  33. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  34. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  35. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  36. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  37. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  38. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  39. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  40. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
  41. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
  42. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides lowering for coreaten to stablehlo for LayerNorm."""
16
+
17
+ import math
18
+ from typing import Optional
19
+ from ai_edge_torch.odml_torch.lowerings import registry
20
+ from ai_edge_torch.odml_torch.lowerings import utils
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
+ # native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
27
+ # Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
28
+ @registry.lower(torch.ops.aten.native_layer_norm)
29
+ def _aten_native_layer_norm(
30
+ lctx,
31
+ data: ir.Value,
32
+ normalized_shape: list[int],
33
+ weight: Optional[ir.Value],
34
+ bias: Optional[ir.Value],
35
+ eps: float,
36
+ ):
37
+ data_type: ir.RankedTensorType = data.type
38
+ unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
39
+ dest_shape = [
40
+ 1,
41
+ unnormalized_count,
42
+ math.prod(normalized_shape),
43
+ ]
44
+ dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
45
+
46
+ reshaped_data = stablehlo.reshape(dest_type, data)
47
+
48
+ one = utils.splat(1, data_type.element_type, [unnormalized_count])
49
+ zero = utils.splat(0, data_type.element_type, [unnormalized_count])
50
+ output, mean, var = stablehlo.batch_norm_training(
51
+ reshaped_data, one, zero, eps, 1
52
+ )
53
+ eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
54
+ rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
55
+
56
+ stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
57
+ normalized_shape
58
+ )
59
+ stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
60
+ mean = stablehlo.reshape(stats_type, mean)
61
+ rstd = stablehlo.reshape(stats_type, rstd)
62
+
63
+ output = stablehlo.reshape(data_type, output)
64
+
65
+ data_rank = len(data_type.shape)
66
+ normalized_rank = len(normalized_shape)
67
+ if weight is not None:
68
+ weight = stablehlo.broadcast_in_dim(
69
+ data_type, weight, list(range(data_rank - normalized_rank, data_rank))
70
+ )
71
+ output = stablehlo.multiply(weight, output)
72
+ if bias is not None:
73
+ bias = stablehlo.broadcast_in_dim(
74
+ data_type, bias, list(range(data_rank - normalized_rank, data_rank))
75
+ )
76
+ output = stablehlo.add(bias, output)
77
+
78
+ return output, mean, rstd
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240910"
16
+ __version__ = "0.3.0.dev20240912"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240910
3
+ Version: 0.3.0.dev20240912
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=e4sh_RFYgNHGoVuOeICnFZtLu1MQCNv7qpq94nKFarU,706
5
+ ai_edge_torch/version.py,sha256=Li1VzlXx5ExydpfV93yVAd78cF1L_g3x30-daYdgsLA,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,24 +39,17 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
39
39
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
40
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
41
41
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
- ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
44
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
46
- ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
48
- ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
49
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
51
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
52
42
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
54
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
55
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
56
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
57
- ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
59
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
43
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
44
+ ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
45
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=oVV1lXgi9cMPES6JmiV8fJOgBQruRdHpyJL7MmXU09M,7283
46
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKtlDXcYLzy-n2moQPLJA_U,9769
47
+ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
48
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
49
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
50
+ ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
51
+ ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
52
+ ai_edge_torch/generative/examples/smallm/smallm.py,sha256=j7SDdcX0WvgQWgpaAi7Gi39Jf0-w9D9PftDbugNrN1M,3919
60
53
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
54
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
62
55
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -78,25 +71,24 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8v
78
71
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
79
72
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
80
73
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
81
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
82
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
74
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
83
75
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
84
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
85
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
76
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
77
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HwoEWls-uJ7oHj0HYxJtgZZhgiBR_OQPXlR6l14vm5E,6778
86
78
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
87
79
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
88
80
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
89
- ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
81
+ ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
90
82
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
91
- ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
83
+ ai_edge_torch/generative/layers/builder.py,sha256=KMwMfZ08r5CXHhcPVZ72nZnIAcsMAIKsv7-QPntlqgI,4418
92
84
  ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
93
- ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
94
- ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
95
- ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
85
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
86
+ ai_edge_torch/generative/layers/model_config.py,sha256=03tjidDM1uo_H0jsHNjYEUR5R1FEckc1GIxSoE7ItQQ,5780
87
+ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
96
88
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
97
89
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
98
90
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
91
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
100
92
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
101
93
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
102
94
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,11 +99,12 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
107
99
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
108
100
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
109
101
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
- ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
111
- ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
113
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
102
+ ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
103
+ ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
104
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
105
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
114
106
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
107
+ ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
115
108
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
116
109
  ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
117
110
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
@@ -145,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
145
138
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
146
139
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
147
140
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
148
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
141
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
149
142
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
150
143
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
151
144
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
152
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
145
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
146
+ ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
153
147
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
154
148
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
155
149
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
155
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
156
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
157
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/METADATA,sha256=WFNExTO6eF-tAWPmDdQDlr9dvplcoNB0uPdVxSNXYHk,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/RECORD,,
158
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
159
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/METADATA,sha256=EjeMjRJ5PeW8Azc8hoiJeMP_WaHUDlCend4DFIeQnzc,1859
160
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
161
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
162
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/RECORD,,
@@ -1,219 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- # Example of building a Gemma model.
16
- #
17
- # Note: This is an experimental version of Gemma with external KV cache.
18
- # Please use with caution.
19
-
20
- import os
21
- from pathlib import Path
22
- from typing import Tuple
23
-
24
- from ai_edge_torch.generative.layers import builder
25
- import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
- from ai_edge_torch.generative.layers.experimental import attention
27
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
28
- import ai_edge_torch.generative.layers.model_config as cfg
29
- import ai_edge_torch.generative.utilities.loader as loading_utils
30
- import numpy as np
31
- import torch
32
- from torch import nn
33
-
34
-
35
- TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
36
- ff_up_proj="model.layers.{}.mlp.up_proj",
37
- ff_down_proj="model.layers.{}.mlp.down_proj",
38
- ff_gate_proj="model.layers.{}.mlp.gate_proj",
39
- attn_query_proj="model.layers.{}.self_attn.q_proj",
40
- attn_key_proj="model.layers.{}.self_attn.k_proj",
41
- attn_value_proj="model.layers.{}.self_attn.v_proj",
42
- attn_output_proj="model.layers.{}.self_attn.o_proj",
43
- pre_attn_norm="model.layers.{}.input_layernorm",
44
- post_attn_norm="model.layers.{}.post_attention_layernorm",
45
- embedding="model.embed_tokens",
46
- final_norm="model.norm",
47
- lm_head=None,
48
- )
49
-
50
-
51
- class Gemma(nn.Module):
52
- """A Gemma model built from the Edge Generative API layers."""
53
-
54
- def __init__(self, config: cfg.ModelConfig):
55
- super().__init__()
56
-
57
- self.config = config
58
- # Construct model layers.
59
- self.tok_embedding = nn.Embedding(
60
- config.vocab_size, config.embedding_dim, padding_idx=0
61
- )
62
- self.lm_head = nn.Linear(
63
- config.embedding_dim,
64
- config.vocab_size,
65
- bias=config.lm_head_use_bias,
66
- )
67
- # Gemma re-uses the embedding as the head projection layer.
68
- self.lm_head.weight.data = self.tok_embedding.weight.data
69
- self.transformer_blocks = nn.ModuleList(
70
- attention.TransformerBlock(config) for _ in range(config.num_layers)
71
- )
72
- self.final_norm = builder.build_norm(
73
- config.embedding_dim,
74
- config.final_norm_config,
75
- )
76
- self.rope_cache = attn_utils.build_rope_cache(
77
- size=config.kv_cache_max,
78
- dim=int(
79
- config.attn_config.rotary_percentage * config.attn_config.head_dim
80
- ),
81
- base=10_000,
82
- condense_ratio=1,
83
- dtype=torch.float32,
84
- device=torch.device("cpu"),
85
- )
86
- self.mask_cache = attn_utils.build_causal_mask_cache(
87
- size=config.kv_cache_max,
88
- dtype=torch.float32,
89
- device=torch.device("cpu"),
90
- )
91
- self.config = config
92
-
93
- @torch.inference_mode
94
- def forward(
95
- self,
96
- tokens: torch.Tensor,
97
- input_pos: torch.Tensor,
98
- kv_cache: kv_utils.EKVCache,
99
- ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
100
- _, seq_len = tokens.size()
101
- assert self.config.max_seq_len >= seq_len, (
102
- f"Cannot forward sequence of length {seq_len}, max seq length is only"
103
- f" {self.config.max_seq_len}"
104
- )
105
-
106
- cos, sin = self.rope_cache
107
- cos = cos.index_select(0, input_pos)
108
- sin = sin.index_select(0, input_pos)
109
- mask = self.mask_cache.index_select(2, input_pos)
110
- mask = mask[:, :, :, : self.config.kv_cache_max]
111
-
112
- # token embeddings of shape (b, t, n_embd)
113
- x = self.tok_embedding(tokens)
114
- x = x * (self.config.embedding_dim**0.5)
115
-
116
- updated_kv_entires = []
117
- for i, block in enumerate(self.transformer_blocks):
118
- kv_entry = kv_cache.caches[i] if kv_cache else None
119
- x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
120
- if kv_entry:
121
- updated_kv_entires.append(kv_entry)
122
- updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
123
-
124
- x = self.final_norm(x)
125
- res = self.lm_head(x) # (b, t, vocab_size)
126
- return res, updated_kv_cache
127
-
128
-
129
- def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
130
- """Returns the model config for a Gemma 2B model.
131
-
132
- Args:
133
- kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
134
- is 1024.
135
-
136
- Returns:
137
- The model config for a Gemma 2B model.
138
- """
139
- attn_config = cfg.AttentionConfig(
140
- num_heads=8,
141
- head_dim=256,
142
- num_query_groups=1,
143
- rotary_percentage=1.0,
144
- )
145
- ff_config = cfg.FeedForwardConfig(
146
- type=cfg.FeedForwardType.GATED,
147
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
148
- intermediate_size=16384,
149
- )
150
- norm_config = cfg.NormalizationConfig(
151
- type=cfg.NormalizationType.RMS_NORM,
152
- epsilon=1e-6,
153
- zero_centered=True,
154
- )
155
- config = cfg.ModelConfig(
156
- vocab_size=256000,
157
- num_layers=18,
158
- max_seq_len=8192,
159
- embedding_dim=2048,
160
- kv_cache_max_len=kv_cache_max_len,
161
- attn_config=attn_config,
162
- ff_config=ff_config,
163
- pre_attention_norm_config=norm_config,
164
- post_attention_norm_config=norm_config,
165
- final_norm_config=norm_config,
166
- parallel_residual=False,
167
- lm_head_use_bias=False,
168
- enable_hlfb=True,
169
- )
170
- return config
171
-
172
-
173
- def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
174
- config = get_model_config_2b(kv_cache_max_len)
175
- config.ff_config.intermediate_size = 128
176
- config.vocab_size = 128
177
- config.num_layers = 2
178
- config.max_seq_len = 2 * kv_cache_max_len
179
- return config
180
-
181
-
182
- def build_2b_model(
183
- checkpoint_path: str, test_model: bool = False, **kwargs
184
- ) -> nn.Module:
185
- """Instantiates the model instance and load checkpoint if provided."""
186
- config = (
187
- get_fake_model_config(**kwargs)
188
- if test_model
189
- else get_model_config_2b(**kwargs)
190
- )
191
- model = Gemma(config)
192
- if checkpoint_path is not None:
193
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
194
- # since embedding and lm-head use the same weight, we need to set strict
195
- # to False.
196
- loader.load(model, strict=False)
197
- model.eval()
198
- return model
199
-
200
-
201
- def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
202
- """Instantiates and runs a Gemma 2B model."""
203
-
204
- kv_cache_max_len = 1024
205
- model = build_2b_model(
206
- checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
207
- )
208
- idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
209
- tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
210
- tokens[0, :4] = idx
211
- input_pos = torch.arange(0, kv_cache_max_len)
212
- kv = kv_utils.EKVCache.from_model_config(model.config)
213
- print("running an inference")
214
- print(model.forward(tokens, input_pos, kv))
215
-
216
-
217
- if __name__ == "__main__":
218
- input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
219
- define_and_run_2b(input_checkpoint_path)
@@ -1,14 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,14 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
@@ -1,87 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- #
16
- # Note: This is an experimental version of TinyLlama with external KV cache.
17
- # Please use with caution.
18
-
19
-
20
- import os
21
- from pathlib import Path
22
-
23
- import ai_edge_torch
24
- from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
25
- from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
26
- from ai_edge_torch.generative.quantize import quant_recipes
27
- import torch
28
-
29
-
30
- def convert_tiny_llama_to_tflite(
31
- checkpoint_path: str,
32
- prefill_seq_len: int = 512,
33
- kv_cache_max_len: int = 1024,
34
- quantize: bool = True,
35
- ):
36
- """An example for converting TinyLlama model to multi-signature tflite model.
37
-
38
- Args:
39
- checkpoint_path (str): The filepath to the model checkpoint, or directory
40
- holding the checkpoint.
41
- prefill_seq_len (int, optional): The maximum size of prefill input tensor.
42
- Defaults to 512.
43
- kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
44
- including both prefill and decode. Defaults to 1024.
45
- quantize (bool, optional): Whether the model should be quanized. Defaults
46
- to True.
47
- """
48
- pytorch_model = tiny_llama.build_model(
49
- checkpoint_path, kv_cache_max_len=kv_cache_max_len
50
- )
51
- # Tensors used to trace the model graph during conversion.
52
- prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
53
- prefill_input_pos = torch.arange(0, prefill_seq_len)
54
- decode_token = torch.tensor([[0]], dtype=torch.long)
55
- decode_input_pos = torch.tensor([0], dtype=torch.int64)
56
- kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
57
-
58
- quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
59
- edge_model = (
60
- ai_edge_torch.signature(
61
- 'prefill',
62
- pytorch_model,
63
- sample_kwargs={
64
- 'tokens': prefill_tokens,
65
- 'input_pos': prefill_input_pos,
66
- 'kv_cache': kv,
67
- },
68
- )
69
- .signature(
70
- 'decode',
71
- pytorch_model,
72
- sample_kwargs={
73
- 'tokens': decode_token,
74
- 'input_pos': decode_input_pos,
75
- 'kv_cache': kv,
76
- },
77
- )
78
- .convert(quant_config=quant_config)
79
- )
80
- edge_model.export(
81
- f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
82
- )
83
-
84
-
85
- if __name__ == '__main__':
86
- checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
87
- convert_tiny_llama_to_tflite(checkpoint_path)