ai-edge-torch-nightly 0.4.0.dev20250311__py3-none-any.whl → 0.4.0.dev20250313__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -247,6 +247,9 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
247
247
  rotary_base=10000,
248
248
  rotary_percentage=1.0,
249
249
  qkv_transpose_before_split=True,
250
+ # The safetensors from HF is not using the interleaved qkv format, so
251
+ # we need to disable interleaving here in the model config.
252
+ qkv_fused_interleaved=False,
250
253
  logit_softcap=50.0,
251
254
  sliding_window_size=4096,
252
255
  attn_type=(
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,124 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting a Gemma3 model to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.gemma3 import gemma3
24
+ from ai_edge_torch.generative.layers.experimental import kv_cache
25
+ from ai_edge_torch.generative.utilities import converter
26
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
27
+ import torch
28
+
29
# Command-line flags for this conversion script. Each flag is declared with
# explicit keyword arguments so the role of every value is visible at the
# call site.

# Which Gemma3 variant to build; main() only handles '1b'.
_MODEL_SIZE = flags.DEFINE_string(
    name='model_size',
    default='1b',
    help='The size of the model to convert.',
)

# The default location is rooted at the current user's home directory.
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
_OUTPUT_PATH = flags.DEFINE_string(
    name='output_path',
    default='/tmp/',
    help='The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    name='output_name_prefix',
    default='gemma3',
    help='The prefix of the output tflite model name.',
)
# Repeatable flag: one prefill mask is created per listed length.
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    name='prefill_seq_lens',
    default=(32, 64, 128, 256, 512, 1024),
    help='List of the maximum sizes of prefill input tensors.',
)
# Also used as the last dimension of every exported attention mask.
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    name='kv_cache_max_len',
    default=2048,
    help='The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    name='quantize',
    default=False,
    help='Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    name='lora_ranks',
    default=None,
    help='If set, the model will be converted with the provided list of LoRA ranks.',
)
70
+
71
+
72
def _create_mask(mask_len, kv_cache_max_len):
  """Builds an upper-triangular -inf mask (causal style).

  Args:
    mask_len: Number of rows in the mask.
    kv_cache_max_len: Number of columns in the mask.

  Returns:
    A float32 tensor of shape (1, 1, mask_len, kv_cache_max_len). Entry
    (i, j) of the trailing 2-D slice is 0.0 when j <= i and -inf when j > i.
  """
  neg_inf = float('-inf')
  filled = torch.full(
      (mask_len, kv_cache_max_len), neg_inf, dtype=torch.float32
  )
  # triu(diagonal=1) keeps -inf strictly above the main diagonal and zeroes
  # the rest; `[None, None]` then prepends two singleton axes.
  return filled.triu(diagonal=1)[None, None]
78
+
79
+
80
def _create_export_config(
    prefill_seq_lens: list[int], kv_cache_max_len: int
) -> ExportConfig:
  """Creates the export config for the model.

  Args:
    prefill_seq_lens: Either a list/tuple of prefill lengths (one mask is built
      per length) or a single length (one mask is built).
    kv_cache_max_len: Size of the KV cache buffer; used as the last dimension
      of every mask.

  Returns:
    An ExportConfig with:
      * prefill_mask: a list of masks when `prefill_seq_lens` is a list or
        tuple, otherwise a single mask;
      * decode_mask: a mask of shape (1, 1, 1, kv_cache_max_len);
      * kvcache_cls: kv_cache.KVCacheTransposed.
  """
  export_config = ExportConfig()
  # Accept tuples as well as lists. Previously a tuple fell through to the
  # scalar branch and crashed inside torch.full with a non-int dimension.
  if isinstance(prefill_seq_lens, (list, tuple)):
    export_config.prefill_mask = [
        _create_mask(seq_len, kv_cache_max_len) for seq_len in prefill_seq_lens
    ]
  else:
    export_config.prefill_mask = _create_mask(
        prefill_seq_lens, kv_cache_max_len
    )

  # The decode mask has a single row. Reuse the shared helper: it yields the
  # same float32 triu(-inf, diagonal=1) tensor that used to be built inline.
  export_config.decode_mask = _create_mask(1, kv_cache_max_len)
  export_config.kvcache_cls = kv_cache.KVCacheTransposed
  return export_config
99
+
100
+
101
def main(_):
  """Builds the Gemma3 model selected by flags and converts it to TFLite.

  The export config comes from _create_export_config(). That helper supplies
  the example prefill/decode masks and kv_cache.KVCacheTransposed as the
  KV cache class.

  Args:
    _: Leftover positional command-line arguments from absl; ignored.

  Raises:
    ValueError: If --model_size is not '1b', the only size handled here.
  """
  model_size = _MODEL_SIZE.value
  # Reject unsupported sizes before doing any model construction.
  if model_size != '1b':
    raise ValueError(f'Unsupported model size: {model_size}')

  kv_cache_max_len = _KV_CACHE_MAX_LEN.value
  pytorch_model = gemma3.build_model_1b(
      _CHECKPOINT_PATH.value, kv_cache_max_len=kv_cache_max_len
  )
  # For the 1B variant the converter receives the model's own config.
  model_config = pytorch_model.config

  export_config = _create_export_config(
      _PREFILL_SEQ_LENS.value, kv_cache_max_len
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      config=model_config,
      lora_ranks=_LORA_RANKS.value,
      export_config=export_config,
  )


if __name__ == '__main__':
  app.run(main)
@@ -0,0 +1,14 @@
1
+ # Copyright 2025 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,96 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting a Gemma3 model to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ from absl import app
22
+ from absl import flags
23
+ from ai_edge_torch.generative.examples.gemma3 import gemma3
24
+ from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
26
+
27
# Command-line flags for this conversion script. Each flag is declared with
# explicit keyword arguments so the role of every value is visible at the
# call site.

# Which Gemma3 variant to build; main() handles '1b' and '4b'.
_MODEL_SIZE = flags.DEFINE_string(
    name='model_size',
    default='1b',
    help='The size of the model to convert.',
)

# The default location is rooted at the current user's home directory.
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
_OUTPUT_PATH = flags.DEFINE_string(
    name='output_path',
    default='/tmp/',
    help='The path to export the tflite model.',
)
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    name='output_name_prefix',
    default='gemma3',
    help='The prefix of the output tflite model name.',
)
# Repeatable flag; forwarded to the converter as `prefill_seq_len`.
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    name='prefill_seq_lens',
    default=(8, 64, 128, 256, 512, 1024),
    help='List of the maximum sizes of prefill input tensors.',
)
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    name='kv_cache_max_len',
    default=1280,
    help='The maximum size of KV cache buffer, including both prefill and decode.',
)
# Unlike the experimental script, quantization is on by default here.
_QUANTIZE = flags.DEFINE_bool(
    name='quantize',
    default=True,
    help='Whether the model should be quantized.',
)
_LORA_RANKS = flags.DEFINE_multi_integer(
    name='lora_ranks',
    default=None,
    help='If set, the model will be converted with the provided list of LoRA ranks.',
)
68
+
69
+
70
def _build_model_and_config():
  """Builds the Gemma3 model chosen by --model_size.

  Returns:
    A (model, config) pair, where config is the object passed to the
    converter's `config` argument.

  Raises:
    ValueError: If --model_size is neither '1b' nor '4b'.
  """
  model_size = _MODEL_SIZE.value
  if model_size == '1b':
    model = gemma3.build_model_1b(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
    return model, model.config
  if model_size == '4b':
    model = gemma3.build_model_4b(
        _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
    )
    # For the 4B variant the converter gets the nested `decoder_config`
    # rather than the top-level config.
    return model, model.config.decoder_config
  raise ValueError(f'Unsupported model size: {model_size}')


def main(_):
  """Builds the Gemma3 model selected by flags and converts it to TFLite.

  Uses a default-constructed ExportConfig.

  Args:
    _: Leftover positional command-line arguments from absl; ignored.

  Raises:
    ValueError: If --model_size is neither '1b' nor '4b'.
  """
  pytorch_model, config = _build_model_and_config()
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      config=config,
      lora_ranks=_LORA_RANKS.value,
      export_config=ExportConfig(),
  )


if __name__ == '__main__':
  app.run(main)