ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
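The converter entry points now take an output directory and a name prefix instead of a caller-assembled file name, and gain an optional list of LoRA ranks. Because `--lora_ranks` is declared with `flags.DEFINE_multi_integer`, it can be repeated on the command line and parses to a list of integers, or stays `None` when omitted. A minimal standalone sketch of that flag behavior (only the flag definition mirrors the diff; the rest of the script is illustrative):

```python
# Standalone illustration of the multi-integer flag added above; not package code.
from absl import app
from absl import flags

_LORA_RANKS = flags.DEFINE_multi_integer(
    'lora_ranks',
    None,
    'If set, the model will be converted with the provided list of LoRA ranks.',
)


def main(_):
  # `--lora_ranks=16 --lora_ranks=32` yields [16, 32]; omitting the flag yields None.
  print(_LORA_RANKS.value)


if __name__ == '__main__':
  app.run(main)
```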
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -15,13 +15,14 @@

 """Example of building a Gemma2 model."""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )

-
-
-
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
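With the precomputed `rope_cache` gone, `forward` asks `rotary_position_embedding.build_rope` for the cos/sin values of exactly the positions in `input_pos`, and builds one attention mask per layer up front. As a rough sketch of what such a helper returns, here is the theta/outer-product construction used by the Llama 3.2 and Phi-3.5 helpers elsewhere in this diff; the real `build_rope` also takes `head_dim` (as the call site above shows), and its exact output layout may differ:

```python
# Illustrative only -- not the library implementation of
# rotary_position_embedding.build_rope.
import torch


def build_rope_sketch(input_pos: torch.Tensor, n_elem: int, head_dim: int,
                      base: int = 10_000):
  del head_dim  # accepted to mirror the call site above; unused in this sketch
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)


# An 8-token prefill; n_elem = rotary_percentage * head_dim, as in Gemma2.forward.
cos, sin = build_rope_sketch(torch.arange(8), n_elem=256, head_dim=256)
print(cos.shape, sin.shape)  # torch.Size([8, 128]) torch.Size([8, 128])
```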
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )

   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
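The √d scaling of token embeddings, previously hard-coded in `forward` as `x * (self.config.embedding_dim**0.5)`, now lives in the config as `embedding_scale` and is applied in `_forward_with_embeds` only when set. A small equivalence check with arbitrary sizes (the real vocabulary and embedding width come from the config above):

```python
import torch

embedding_dim = 64                     # arbitrary here; 2304 in get_model_config_2b()
embedding_scale = embedding_dim**0.5   # what the config now carries

tok_embedding = torch.nn.Embedding(1000, embedding_dim)
tokens = torch.randint(0, 1000, (1, 4))

old_x = tok_embedding(tokens) * (embedding_dim**0.5)  # previous hard-coded path
new_x = tok_embedding(tokens)
if embedding_scale is not None:                       # new config-driven path
  new_x = new_x * embedding_scale

print(torch.equal(old_x, new_x))  # True
```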
ai_edge_torch/generative/examples/llama/convert_to_tflite.py

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)

 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/llama/llama.py

@@ -15,6 +15,7 @@

 """Example of building Llama 3.2 models."""

+from functools import partial
 import math
 from typing import Tuple

@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES


 def _build_llama3_rope_cache(
-
-
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.

   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307

   Args:
-
-
-    base (int
-    condense_ratio (int
-
-
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0,
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)

-  seq_idx =
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )


 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config

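With the per-model RoPE helpers no longer filling a full-length cache in `__init__`, the model config instead carries a `build_rope` callable; `functools.partial` pre-binds the Llama 3.2 specific parameters so that callers only pass the per-invocation arguments. A toy sketch of the mechanics follows; the stub omits the wavelength-based theta rescaling shown in the hunk above and is not the package implementation, and `base=500_000` is an assumed rotary base for illustration:

```python
from functools import partial

import torch


def _llama3_rope_stub(input_pos, n_elem, base, condense_ratio, dtype, device,
                      factor, low_freq_factor, high_freq_factor, max_seq_len,
                      **kwargs):
  # Plain RoPE only; the real _build_llama3_rope_cache also rescales `theta`
  # by wavelength using factor / low_freq_factor / high_freq_factor / max_seq_len.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  seq_idx = input_pos / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  return (torch.cos(idx_theta).to(dtype=dtype, device=device),
          torch.sin(idx_theta).to(dtype=dtype, device=device))


# Pre-bind the Llama 3.2 constants, as get_1b_model_config() does above.
build_rope = partial(
    _llama3_rope_stub,
    condense_ratio=1,
    dtype=torch.float32,
    device=torch.device("cpu"),
    factor=32.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    max_seq_len=8192,
)

# Per-call arguments only.
cos, sin = build_rope(torch.arange(16), n_elem=64, base=500_000)
print(cos.shape, sin.shape)  # torch.Size([16, 32]) torch.Size([16, 32])
```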
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-    '
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/phi3.py

@@ -15,6 +15,7 @@

 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""

+from functools import partial
 import math
 from typing import Tuple

@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]


-def
-
-
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.

   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.

   Args:
-
-
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
       condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.

   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0,
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx =
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config

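The scale applied to the Phi-3.5 cos/sin values, previously computed in `Phi3_5Mini.__init__` from `config.max_seq_len`, is now computed once in `get_model_config()` from the local `max_seq_len` and pre-bound into `build_rope`. For a feel of its magnitude, with an assumed `ROPE_SCALE_FACTOR` of 32 (the real constant is defined earlier in phi3.py and is not shown in these hunks):

```python
import math

max_seq_len = 4096
rope_scale_factor = 32  # assumed value for illustration only

scale = math.sqrt(1 + math.log(rope_scale_factor) / math.log(max_seq_len))
print(round(scale, 4))  # 1.1902 -- cos and sin are multiplied by this in _build_phi3_rope
```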