ai-edge-torch-nightly 0.3.0.dev20240910-py3-none-any.whl → 0.3.0.dev20240912-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +34 -18
- ai_edge_torch/generative/examples/gemma/gemma2.py +38 -17
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +31 -33
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +119 -0
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +58 -25
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +40 -24
- ai_edge_torch/generative/layers/attention.py +60 -63
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/kv_cache.py +160 -51
- ai_edge_torch/generative/layers/model_config.py +1 -0
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +8 -22
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD +33 -39
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/top_level.txt +0 -0
ai_edge_torch/odml_torch/lowerings/_layer_norm.py ADDED
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides lowering for coreaten to stablehlo for LayerNorm."""
+
+import math
+from typing import Optional
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
+# Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+@registry.lower(torch.ops.aten.native_layer_norm)
+def _aten_native_layer_norm(
+    lctx,
+    data: ir.Value,
+    normalized_shape: list[int],
+    weight: Optional[ir.Value],
+    bias: Optional[ir.Value],
+    eps: float,
+):
+  data_type: ir.RankedTensorType = data.type
+  unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
+  dest_shape = [
+      1,
+      unnormalized_count,
+      math.prod(normalized_shape),
+  ]
+  dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
+
+  reshaped_data = stablehlo.reshape(dest_type, data)
+
+  one = utils.splat(1, data_type.element_type, [unnormalized_count])
+  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
+  output, mean, var = stablehlo.batch_norm_training(
+      reshaped_data, one, zero, eps, 1
+  )
+  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
+  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
+
+  stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
+      normalized_shape
+  )
+  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
+  mean = stablehlo.reshape(stats_type, mean)
+  rstd = stablehlo.reshape(stats_type, rstd)
+
+  output = stablehlo.reshape(data_type, output)
+
+  data_rank = len(data_type.shape)
+  normalized_rank = len(normalized_shape)
+  if weight is not None:
+    weight = stablehlo.broadcast_in_dim(
+        data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.multiply(weight, output)
+  if bias is not None:
+    bias = stablehlo.broadcast_in_dim(
+        data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.add(bias, output)
+
+  return output, mean, rstd
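The trick in this new lowering is worth spelling out: layer norm over normalized_shape is computed by collapsing the input to shape (1, unnormalized_count, prod(normalized_shape)) and running a batch-norm training op with unit scale and zero offset over the feature axis, then reapplying the optional weight and bias. The same equivalence can be checked numerically in plain PyTorch; the helper below is a hypothetical sketch written only to illustrate the idea, not code from the package:

import math
import torch
import torch.nn.functional as F

def layer_norm_via_batch_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
  # Collapse to (1, groups, features). With the channel axis at dim 1 and a
  # batch of 1, batch norm computes per-group statistics over `features`,
  # which is exactly layer norm over `normalized_shape`.
  features = math.prod(normalized_shape)
  groups = x.numel() // features
  out = F.batch_norm(
      x.reshape(1, groups, features),
      running_mean=None,
      running_var=None,
      training=True,
      eps=eps,
  ).reshape(x.shape)
  if weight is not None:
    out = out * weight
  if bias is not None:
    out = out + bias
  return out

x = torch.randn(2, 3, 8)
assert torch.allclose(
    layer_norm_via_batch_norm(x, [8]), F.layer_norm(x, [8]), atol=1e-5
)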
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240910
+Version: 0.3.0.dev20240912
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=Li1VzlXx5ExydpfV93yVAd78cF1L_g3x30-daYdgsLA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,24 +39,17 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
 ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
-ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
-ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
-ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
-ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
-ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
-ai_edge_torch/generative/examples/
-ai_edge_torch/generative/examples/
-ai_edge_torch/generative/examples/
+ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
+ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
+ai_edge_torch/generative/examples/gemma/gemma.py,sha256=oVV1lXgi9cMPES6JmiV8fJOgBQruRdHpyJL7MmXU09M,7283
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKtlDXcYLzy-n2moQPLJA_U,9769
+ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
+ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
+ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
+ai_edge_torch/generative/examples/smallm/smallm.py,sha256=j7SDdcX0WvgQWgpaAi7Gi39Jf0-w9D9PftDbugNrN1M,3919
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -78,25 +71,24 @@ ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8v
 ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
-ai_edge_torch/generative/examples/test_models/
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
+ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HwoEWls-uJ7oHj0HYxJtgZZhgiBR_OQPXlR6l14vm5E,6778
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=
+ai_edge_torch/generative/layers/builder.py,sha256=KMwMfZ08r5CXHhcPVZ72nZnIAcsMAIKsv7-QPntlqgI,4418
 ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
-ai_edge_torch/generative/layers/kv_cache.py,sha256=
-ai_edge_torch/generative/layers/model_config.py,sha256=
-ai_edge_torch/generative/layers/normalization.py,sha256=
+ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
+ai_edge_torch/generative/layers/model_config.py,sha256=03tjidDM1uo_H0jsHNjYEUR5R1FEckc1GIxSoE7ItQQ,5780
+ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,11 +99,12 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
 ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/test/
-ai_edge_torch/generative/test/test_loader.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
+ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
 ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
+ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
 ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
@@ -145,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
+ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/METADATA,sha256=EjeMjRJ5PeW8Azc8hoiJeMP_WaHUDlCend4DFIeQnzc,1859
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/RECORD,,
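For readers unfamiliar with the RECORD file: each entry is "path,sha256=<digest>,<size in bytes>", where the digest is the urlsafe base64 of the file's SHA-256 hash with trailing "=" padding stripped (PEP 376 / PEP 427), so the churn above is just a byte-for-byte fingerprint of every changed file. A small sketch of recomputing one entry's hash field (the function name is ours, for illustration):

import base64
import hashlib

def record_digest(path: str) -> str:
  """Recomputes the sha256= field of a wheel RECORD entry for `path`."""
  with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).digest()
  # RECORD uses urlsafe base64 with the trailing '=' padding stripped.
  return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")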
ai_edge_torch/generative/examples/experimental/gemma/gemma.py DELETED
@@ -1,219 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Example of building a Gemma model.
-#
-# Note: This is an experimental version of Gemma with external KV cache.
-# Please use with caution.
-
-import os
-from pathlib import Path
-from typing import Tuple
-
-from ai_edge_torch.generative.layers import builder
-import ai_edge_torch.generative.layers.attention_utils as attn_utils
-from ai_edge_torch.generative.layers.experimental import attention
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.utilities.loader as loading_utils
-import numpy as np
-import torch
-from torch import nn
-
-
-TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="model.layers.{}.mlp.up_proj",
-    ff_down_proj="model.layers.{}.mlp.down_proj",
-    ff_gate_proj="model.layers.{}.mlp.gate_proj",
-    attn_query_proj="model.layers.{}.self_attn.q_proj",
-    attn_key_proj="model.layers.{}.self_attn.k_proj",
-    attn_value_proj="model.layers.{}.self_attn.v_proj",
-    attn_output_proj="model.layers.{}.self_attn.o_proj",
-    pre_attn_norm="model.layers.{}.input_layernorm",
-    post_attn_norm="model.layers.{}.post_attention_layernorm",
-    embedding="model.embed_tokens",
-    final_norm="model.norm",
-    lm_head=None,
-)
-
-
-class Gemma(nn.Module):
-  """A Gemma model built from the Edge Generative API layers."""
-
-  def __init__(self, config: cfg.ModelConfig):
-    super().__init__()
-
-    self.config = config
-    # Construct model layers.
-    self.tok_embedding = nn.Embedding(
-        config.vocab_size, config.embedding_dim, padding_idx=0
-    )
-    self.lm_head = nn.Linear(
-        config.embedding_dim,
-        config.vocab_size,
-        bias=config.lm_head_use_bias,
-    )
-    # Gemma re-uses the embedding as the head projection layer.
-    self.lm_head.weight.data = self.tok_embedding.weight.data
-    self.transformer_blocks = nn.ModuleList(
-        attention.TransformerBlock(config) for _ in range(config.num_layers)
-    )
-    self.final_norm = builder.build_norm(
-        config.embedding_dim,
-        config.final_norm_config,
-    )
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(
-            config.attn_config.rotary_percentage * config.attn_config.head_dim
-        ),
-        base=10_000,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-    )
-    self.config = config
-
-  @torch.inference_mode
-  def forward(
-      self,
-      tokens: torch.Tensor,
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.EKVCache,
-  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
-    _, seq_len = tokens.size()
-    assert self.config.max_seq_len >= seq_len, (
-        f"Cannot forward sequence of length {seq_len}, max seq length is only"
-        f" {self.config.max_seq_len}"
-    )
-
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-    mask = self.mask_cache.index_select(2, input_pos)
-    mask = mask[:, :, :, : self.config.kv_cache_max]
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
-    for i, block in enumerate(self.transformer_blocks):
-      kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
-      if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.EKVCache(tuple(updated_kv_entires))
-
-    x = self.final_norm(x)
-    res = self.lm_head(x)  # (b, t, vocab_size)
-    return res, updated_kv_cache
-
-
-def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
-  """Returns the model config for a Gemma 2B model.
-
-  Args:
-    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
-      is 1024.
-
-  Returns:
-    The model config for a Gemma 2B model.
-  """
-  attn_config = cfg.AttentionConfig(
-      num_heads=8,
-      head_dim=256,
-      num_query_groups=1,
-      rotary_percentage=1.0,
-  )
-  ff_config = cfg.FeedForwardConfig(
-      type=cfg.FeedForwardType.GATED,
-      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
-      intermediate_size=16384,
-  )
-  norm_config = cfg.NormalizationConfig(
-      type=cfg.NormalizationType.RMS_NORM,
-      epsilon=1e-6,
-      zero_centered=True,
-  )
-  config = cfg.ModelConfig(
-      vocab_size=256000,
-      num_layers=18,
-      max_seq_len=8192,
-      embedding_dim=2048,
-      kv_cache_max_len=kv_cache_max_len,
-      attn_config=attn_config,
-      ff_config=ff_config,
-      pre_attention_norm_config=norm_config,
-      post_attention_norm_config=norm_config,
-      final_norm_config=norm_config,
-      parallel_residual=False,
-      lm_head_use_bias=False,
-      enable_hlfb=True,
-  )
-  return config
-
-
-def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
-  config = get_model_config_2b(kv_cache_max_len)
-  config.ff_config.intermediate_size = 128
-  config.vocab_size = 128
-  config.num_layers = 2
-  config.max_seq_len = 2 * kv_cache_max_len
-  return config
-
-
-def build_2b_model(
-    checkpoint_path: str, test_model: bool = False, **kwargs
-) -> nn.Module:
-  """Instantiates the model instance and load checkpoint if provided."""
-  config = (
-      get_fake_model_config(**kwargs)
-      if test_model
-      else get_model_config_2b(**kwargs)
-  )
-  model = Gemma(config)
-  if checkpoint_path is not None:
-    loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
-    # since embedding and lm-head use the same weight, we need to set strict
-    # to False.
-    loader.load(model, strict=False)
-  model.eval()
-  return model
-
-
-def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
-  """Instantiates and runs a Gemma 2B model."""
-
-  kv_cache_max_len = 1024
-  model = build_2b_model(
-      checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
-  )
-  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
-  tokens[0, :4] = idx
-  input_pos = torch.arange(0, kv_cache_max_len)
-  kv = kv_utils.EKVCache.from_model_config(model.config)
-  print("running an inference")
-  print(model.forward(tokens, input_pos, kv))
-
-
-if __name__ == "__main__":
-  input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
-  define_and_run_2b(input_checkpoint_path)
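This deletion is a promotion, not a removal: the external-KV-cache calling convention shown above (tokens and input_pos in, logits plus an updated cache out) moves into the mainline gemma.py, kv_cache.py, and attention.py touched earlier in this diff. A hedged sketch of the equivalent smoke test against the promoted modules, assuming EKVCache became KVCache in ai_edge_torch.generative.layers.kv_cache and that build_2b_model and the forward contract carried over (all inferred from this diff's file list, not verified against the new sources):

import torch
from ai_edge_torch.generative.examples.gemma import gemma
from ai_edge_torch.generative.layers import kv_cache as kv_utils

checkpoint_path = "/path/to/gemma-2b"  # placeholder
kv_cache_max_len = 1024
model = gemma.build_2b_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)

tokens = torch.zeros((1, kv_cache_max_len), dtype=torch.long)
tokens[0, :4] = torch.tensor([1, 2, 3, 4])
input_pos = torch.arange(0, kv_cache_max_len)
kv = kv_utils.KVCache.from_model_config(model.config)  # assumed rename of EKVCache
logits, kv = model.forward(tokens, input_pos, kv)  # assumed unchanged contract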
ai_edge_torch/generative/examples/experimental/phi/__init__.py DELETED
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py DELETED
@@ -1,14 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py DELETED
@@ -1,87 +0,0 @@
-# Copyright 2024 The AI Edge Torch Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Note: This is an experimental version of TinyLlama with external KV cache.
-# Please use with caution.
-
-
-import os
-from pathlib import Path
-
-import ai_edge_torch
-from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama
-from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
-from ai_edge_torch.generative.quantize import quant_recipes
-import torch
-
-
-def convert_tiny_llama_to_tflite(
-    checkpoint_path: str,
-    prefill_seq_len: int = 512,
-    kv_cache_max_len: int = 1024,
-    quantize: bool = True,
-):
-  """An example for converting TinyLlama model to multi-signature tflite model.
-
-  Args:
-    checkpoint_path (str): The filepath to the model checkpoint, or directory
-      holding the checkpoint.
-    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
-      Defaults to 512.
-    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
-      including both prefill and decode. Defaults to 1024.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
-  pytorch_model = tiny_llama.build_model(
-      checkpoint_path, kv_cache_max_len=kv_cache_max_len
-  )
-  # Tensors used to trace the model graph during conversion.
-  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
-  prefill_input_pos = torch.arange(0, prefill_seq_len)
-  decode_token = torch.tensor([[0]], dtype=torch.long)
-  decode_input_pos = torch.tensor([0], dtype=torch.int64)
-  kv = kv_utils.EKVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
-  )
-  edge_model.export(
-      f'/tmp/tiny_llama_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
-  )
-
-
-if __name__ == '__main__':
-  checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
-  convert_tiny_llama_to_tflite(checkpoint_path)
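Per the rename table at the top, this converter did not disappear either: the mainline ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py was updated in the same change, so the multi-signature prefill/decode recipe above now lives there. A hedged sketch of invoking the promoted entry point, assuming it kept the function name and defaults shown above (not verified against the new sources):

import os
from pathlib import Path

from ai_edge_torch.generative.examples.tiny_llama import convert_to_tflite

# Assumed: the promoted module keeps convert_tiny_llama_to_tflite with the
# same defaults (prefill_seq_len=512, kv_cache_max_len=1024, quantize=True).
checkpoint_path = os.path.join(Path.home(), 'Downloads/llm_data/tiny_llama')
convert_to_tflite.convert_tiny_llama_to_tflite(checkpoint_path)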