ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/llama/llama.py +29 -25
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
- ai_edge_torch/generative/examples/phi/phi3.py +26 -23
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
- ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
- ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
- ai_edge_torch/generative/examples/smollm/verify.py +18 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
- ai_edge_torch/generative/layers/attention.py +45 -37
- ai_edge_torch/generative/layers/lora.py +557 -0
- ai_edge_torch/generative/layers/model_config.py +6 -2
- ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
- ai_edge_torch/generative/test/test_lora.py +147 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/converter.py +100 -47
- ai_edge_torch/generative/utilities/model_builder.py +23 -14
- ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
- ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
- ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
- ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
- ai_edge_torch/odml_torch/_torch_future.py +13 -0
- ai_edge_torch/odml_torch/export.py +6 -2
- ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
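Note: the same flag change is applied to each conversion script below: the old tflite_path flag is replaced by output_path plus output_name_prefix, and a new multi-integer lora_ranks flag is threaded through to converter.convert_to_tflite. As an illustrative sketch only (the checkpoint path, kv_cache_max_len, prefill lengths and LoRA ranks below are placeholders, not values from this diff), the updated converter call can be driven directly from Python like this:

# Sketch of the new converter.convert_to_tflite keyword arguments introduced
# in this release; values marked as placeholders are assumptions.
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = gemma1.build_2b_model(
    '/path/to/gemma-2b',      # placeholder checkpoint directory
    kv_cache_max_len=1024,    # placeholder KV cache size
)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',            # replaces the old tflite_path flag
    output_name_prefix='gemma',     # prefix of the generated .tflite file name
    prefill_seq_len=[128, 512],     # placeholder prefill lengths
    quantize=True,
    lora_ranks=[4, 16],             # new: also export signatures for these LoRA ranks
    export_config=ExportConfig(),
)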
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/gemma/gemma2.py

@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
    )
 
-
-
-
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-
-        updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
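The gemma2.py change above removes the precomputed rope_cache from __init__ and instead builds the RoPE (cos, sin) tensors inside forward() from the current input_pos, handing them together with per-layer masks to the new _forward_with_embeds. A minimal sketch of that call, using only what the hunk shows (head_dim, rotary_percentage and rotary_base values below are illustrative assumptions):

# Sketch: RoPE is now derived on the fly from the positions being processed,
# mirroring the new forward() body above. Values are illustrative only.
import torch
import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb

input_pos = torch.arange(0, 16)                # positions of the current tokens
head_dim = 256                                 # assumed head dimension
rotary_percentage, rotary_base = 1.0, 10_000   # assumed RoPE hyperparameters
n_elem = int(rotary_percentage * head_dim)
cos, sin = rotary_pos_emb.build_rope(input_pos, n_elem, head_dim, rotary_base)
# (cos, sin) are then passed to every transformer block, instead of each block
# slicing a rope_cache that was precomputed for kv_cache_max positions.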
ai_edge_torch/generative/examples/llama/convert_to_tflite.py

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/llama/llama.py

@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-
-
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-
-
-    base (int
-    condense_ratio (int
-
-
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0,
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx =
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
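The llama.py hunks show the new configuration pattern: instead of each model precomputing self.rope_cache in __init__, the model-specific RoPE builder is bound to its constants with functools.partial and stored on cfg.ModelConfig as build_rope, to be called later with only the per-step arguments. The **kwargs added to the builders presumably lets the shared caller pass a uniform argument set that individual builders may ignore. A self-contained toy version of the pattern (the _toy_rope function and its constants are stand-ins, not the library's _build_llama3_rope_cache):

# Toy sketch of the partial-bound RoPE builder pattern introduced here.
from functools import partial
from typing import Tuple

import torch


def _toy_rope(
    input_pos: torch.Tensor,
    n_elem: int,
    base: int,
    condense_ratio: int,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
  # Same core formula as the hunk above: theta per even channel, positions
  # condensed by condense_ratio, outer product, then cos/sin.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  seq_idx = input_pos / condense_ratio
  idx_theta = torch.outer(seq_idx, theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)


# Model-specific constants are frozen once, when the config is built.
build_rope = partial(_toy_rope, base=10_000, condense_ratio=1)

# At run time only per-call arguments are supplied; extra keyword arguments
# (head_dim here) are absorbed by **kwargs, as in the modified builders.
cos, sin = build_rope(torch.arange(0, 8), n_elem=64, head_dim=64)
print(cos.shape, sin.shape)  # torch.Size([8, 32]) torch.Size([8, 32])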
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py

@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py

@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/convert_to_tflite.py

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-
-'
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-'The
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 
 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )
ai_edge_torch/generative/examples/phi/phi3.py

@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def
-
-
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-
-
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
      condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0,
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx =
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
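For reference, the scale passed to the Phi-3.5 RoPE builder above follows sqrt(1 + ln(ROPE_SCALE_FACTOR) / ln(max_seq_len)). A quick check with an assumed ROPE_SCALE_FACTOR (the real constant is defined earlier in phi3.py and is not part of this hunk):

import math

ROPE_SCALE_FACTOR = 32   # assumption for illustration only
max_seq_len = 4096
scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len))
print(round(scale, 4))   # 1.1902 with the assumed factor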