ai-edge-torch-nightly 0.4.0.dev20250310__py3-none-any.whl → 0.4.0.dev20250312__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py +124 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/__init__.py +14 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/convert_gemma3_to_tflite.py +96 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/decoder.py +463 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/gemma3.py +212 -0
- ai_edge_torch/generative/examples/gemma3/cpu_only/image_encoder.py +149 -0
- ai_edge_torch/generative/examples/gemma3/decoder.py +436 -0
- ai_edge_torch/generative/examples/gemma3/gemma3.py +176 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -3
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +14 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250310.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.4.0.dev20250310.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/RECORD +17 -8
- {ai_edge_torch_nightly-0.4.0.dev20250310.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250310.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.4.0.dev20250310.dist-info → ai_edge_torch_nightly-0.4.0.dev20250312.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -0,0 +1,124 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of converting a Gemma3 model to a multi-signature TFLite model."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import pathlib
|
20
|
+
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
23
|
+
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
24
|
+
from ai_edge_torch.generative.layers.experimental import kv_cache
|
25
|
+
from ai_edge_torch.generative.utilities import converter
|
26
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
27
|
+
import torch
|
28
|
+
|
29
|
+
# Command-line flags for the Gemma3 -> TFLite conversion script.

# Which Gemma3 variant to build; main() only handles '1b'.
_MODEL_SIZE = flags.DEFINE_string(
    name='model_size',
    default='1b',
    help='The size of the model to convert.',
)

# Checkpoint file, or a directory containing it.
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
# Destination directory of the exported model.
_OUTPUT_PATH = flags.DEFINE_string(
    name='output_path',
    default='/tmp/',
    help='The path to export the tflite model.',
)
# File-name prefix for the exported model.
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    name='output_name_prefix',
    default='gemma3',
    help='The prefix of the output tflite model name.',
)
# One prefill signature is exported per listed length.
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    name='prefill_seq_lens',
    default=(32, 64, 128, 256, 512, 1024),
    help='List of the maximum sizes of prefill input tensors.',
)
# Total KV cache capacity (prefill + decode tokens).
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    name='kv_cache_max_len',
    default=2048,
    help='The maximum size of KV cache buffer, including both prefill and decode.',
)
_QUANTIZE = flags.DEFINE_bool(
    name='quantize',
    default=False,
    help='Whether the model should be quantized.',
)
# Optional LoRA ranks; None means no LoRA conversion.
_LORA_RANKS = flags.DEFINE_multi_integer(
    name='lora_ranks',
    default=None,
    help='If set, the model will be converted with the provided list of LoRA ranks.',
)
|
70
|
+
|
71
|
+
|
72
|
+
def _create_mask(mask_len, kv_cache_max_len):
  """Builds a causal attention mask of shape (1, 1, mask_len, kv_cache_max_len).

  Entry [i, j] of the trailing 2-D slice is 0.0 when j <= i and -inf
  otherwise, stored as float32.

  Args:
    mask_len: Number of query rows in the mask.
    kv_cache_max_len: Number of key columns in the mask.

  Returns:
    A float32 tensor with two leading singleton dimensions.
  """
  # Start fully masked, then keep -inf only strictly above the main diagonal.
  causal = torch.full(
      (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
  ).triu(diagonal=1)
  # Add the two leading singleton dims (same result as unsqueeze(0) twice).
  return causal[None, None, :, :]
|
78
|
+
|
79
|
+
|
80
|
+
def _create_export_config(
    prefill_seq_lens: list[int], kv_cache_max_len: int
) -> ExportConfig:
  """Creates the export config for the model.

  Args:
    prefill_seq_lens: The prefill sequence length(s). A list or tuple produces
      one prefill mask per length (returned as a list); a single int produces
      a single mask tensor.
    kv_cache_max_len: The maximum size of the KV cache buffer; it is the last
      dimension of every mask created here.

  Returns:
    An ExportConfig whose `prefill_mask` and `decode_mask` are causal masks
    built by `_create_mask`, and whose `kvcache_cls` is
    `kv_cache.KVCacheTransposed`.
  """
  export_config = ExportConfig()
  # Accept any list/tuple of lengths; previously a tuple fell through to the
  # single-int branch and failed inside torch.full.
  if isinstance(prefill_seq_lens, (list, tuple)):
    prefill_mask = [
        _create_mask(seq_len, kv_cache_max_len) for seq_len in prefill_seq_lens
    ]
  else:
    prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)
  export_config.prefill_mask = prefill_mask

  # The decode mask is the single-row (length 1) case of the same causal mask.
  export_config.decode_mask = _create_mask(1, kv_cache_max_len)
  export_config.kvcache_cls = kv_cache.KVCacheTransposed
  return export_config
|
99
|
+
|
100
|
+
|
101
|
+
def main(_):
  """Builds the requested Gemma3 model and converts it to a TFLite model.

  Only the '1b' model size is handled.

  Raises:
    ValueError: If `--model_size` is anything other than '1b'.
  """
  model_size = _MODEL_SIZE.value
  if model_size != '1b':
    raise ValueError(f'Unsupported model size: {model_size}')

  pytorch_model = gemma3.build_model_1b(
      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
  )
  # Masks for every prefill length plus the decode step.
  export_config = _create_export_config(
      _PREFILL_SEQ_LENS.value, _KV_CACHE_MAX_LEN.value
  )
  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      config=pytorch_model.config,
      lora_ranks=_LORA_RANKS.value,
      export_config=export_config,
  )


if __name__ == '__main__':
  app.run(main)
|
@@ -0,0 +1,14 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of converting a Gemma3 model to a multi-signature TFLite model."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import pathlib
|
20
|
+
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
23
|
+
from ai_edge_torch.generative.examples.gemma3 import gemma3
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
25
|
+
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
26
|
+
|
27
|
+
# Command-line flags for the CPU-only Gemma3 -> TFLite conversion script.

# Which Gemma3 variant to build; main() handles '1b' and '4b'.
_MODEL_SIZE = flags.DEFINE_string(
    name='model_size',
    default='1b',
    help='The size of the model to convert.',
)

# Checkpoint file, or a directory containing it.
_CHECKPOINT_PATH = flags.DEFINE_string(
    name='checkpoint_path',
    default=os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma3-1b'),
    help='The path to the model checkpoint, or directory holding the checkpoint.',
)
# Destination directory of the exported model.
_OUTPUT_PATH = flags.DEFINE_string(
    name='output_path',
    default='/tmp/',
    help='The path to export the tflite model.',
)
# File-name prefix for the exported model.
_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
    name='output_name_prefix',
    default='gemma3',
    help='The prefix of the output tflite model name.',
)
# One prefill signature is exported per listed length.
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    name='prefill_seq_lens',
    default=(8, 64, 128, 256, 512, 1024),
    help='List of the maximum sizes of prefill input tensors.',
)
# Total KV cache capacity (prefill + decode tokens).
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
    name='kv_cache_max_len',
    default=1280,
    help='The maximum size of KV cache buffer, including both prefill and decode.',
)
# Quantization is on by default in this variant.
_QUANTIZE = flags.DEFINE_bool(
    name='quantize',
    default=True,
    help='Whether the model should be quantized.',
)
# Optional LoRA ranks; None means no LoRA conversion.
_LORA_RANKS = flags.DEFINE_multi_integer(
    name='lora_ranks',
    default=None,
    help='If set, the model will be converted with the provided list of LoRA ranks.',
)
|
68
|
+
|
69
|
+
|
70
|
+
def main(_):
  """Builds the requested Gemma3 model and converts it to a TFLite model.

  Uses a default-constructed ExportConfig (no custom masks or KV cache class).

  Raises:
    ValueError: If `--model_size` is neither '1b' nor '4b'.
  """
  model_size = _MODEL_SIZE.value
  checkpoint_path = _CHECKPOINT_PATH.value
  kv_cache_max_len = _KV_CACHE_MAX_LEN.value

  if model_size == '1b':
    pytorch_model = gemma3.build_model_1b(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
    )
    config = pytorch_model.config
  elif model_size == '4b':
    # NOTE(review): `gemma3` is imported from the parent `gemma3` package, not
    # from `cpu_only/gemma3.py`; confirm that module defines `build_model_4b`.
    pytorch_model = gemma3.build_model_4b(
        checkpoint_path, kv_cache_max_len=kv_cache_max_len
    )
    # For 4b, the converter receives the nested `decoder_config`.
    config = pytorch_model.config.decoder_config
  else:
    raise ValueError(f'Unsupported model size: {model_size}')

  converter.convert_to_tflite(
      pytorch_model,
      output_path=_OUTPUT_PATH.value,
      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
      prefill_seq_len=_PREFILL_SEQ_LENS.value,
      quantize=_QUANTIZE.value,
      config=config,
      lora_ranks=_LORA_RANKS.value,
      export_config=ExportConfig(),
  )


if __name__ == '__main__':
  app.run(main)
|