ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0

ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py (+16 -6)

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

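Note: output file naming now happens inside converter.convert_to_tflite instead of in each example script; the converter presumably derives the final name (e.g. gemma_q8_ekv1280.tflite) from the prefix, quantization mode, and KV-cache size, matching what the removed lines used to build (the converter.py change, +100 -47, is not shown in this section). A minimal sketch of the updated call with hypothetical values:

    # Hypothetical values; mirrors the new main() above. The output file
    # name is now composed inside the converter, not the script.
    pytorch_model = gemma1.build_2b_model(
        '/path/to/gemma-2b', kv_cache_max_len=1280
    )
    converter.convert_to_tflite(
        pytorch_model,
        output_path='/tmp/',
        output_name_prefix='gemma',
        prefill_seq_len=[256, 512],
        quantize=True,
        lora_ranks=[4, 16],  # hypothetical; None (the default) disables LoRA
        export_config=ExportConfig(),
    )

The same flag triple (output_path, output_name_prefix, lora_ranks) is applied to every conversion script below.
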
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py (+16 -6)

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/gemma/gemma2.py (+46 -25)

@@ -15,13 +15,14 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
           f" {self.config.max_seq_len}"
       )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
 
-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
 
     if export_config is not None:
       if (
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4

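Note: Gemma2 no longer builds a kv_cache_max-sized RoPE table in __init__; cos/sin are computed on each forward call, only for the positions in input_pos, via rotary_pos_emb.build_rope. A minimal sketch of what such a builder computes, assuming the standard RoPE formulation that the Llama and Phi builders in this diff also use (the real helper in layers/rotary_position_embedding.py may distribute the values across head_dim differently):

    import torch
    from typing import Tuple

    def build_rope_sketch(
        input_pos: torch.Tensor,  # (seq_len,) positions being decoded
        n_elem: int,  # int(rotary_percentage * head_dim)
        head_dim: int,  # kept for signature parity; unused in this sketch
        base: int = 10_000,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
      # One frequency per rotated channel pair.
      theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      # Angles only for the requested positions; no full-size cache needed.
      idx_theta = torch.outer(input_pos.float(), theta)
      return torch.cos(idx_theta), torch.sin(idx_theta)

Computing on demand keeps a kv_cache_max x n_elem constant out of the exported graph, and the cos/sin shapes simply follow the actual prefill or decode length.
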
ai_edge_torch/generative/examples/llama/convert_to_tflite.py (+16 -6)

@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)
 
 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/llama/llama.py (+29 -25)

@@ -15,6 +15,7 @@
 
 """Example of building Llama 3.2 models."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES
 
 
 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)
 
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )
 
 
 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
      post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config

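Note: llama.py no longer overrides rope_cache in Llama.__init__; the Llama-specific RoPE recipe is pre-bound with functools.partial and handed to the shared decoder through cfg.ModelConfig(build_rope=...). An illustration of the pattern with hypothetical names (not the library's actual call site); the **kwargs added to _build_llama3_rope_cache is what lets a generic caller pass arguments, such as head_dim, that this particular builder ignores:

    from functools import partial

    import torch

    def my_rope(input_pos, n_elem, base, condense_ratio, **kwargs):
      # Same skeleton as _build_llama3_rope_cache minus the frequency
      # smoothing; **kwargs silently absorbs arguments we don't use.
      theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
      idx_theta = torch.outer(input_pos / condense_ratio, theta)
      return torch.cos(idx_theta), torch.sin(idx_theta)

    # Model-specific arguments are bound once, when the config is built...
    build_rope = partial(my_rope, condense_ratio=1)
    # ...and the shared decoder supplies only the per-call arguments.
    cos, sin = build_rope(
        input_pos=torch.arange(16), n_elem=32, base=500_000, head_dim=64
    )
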
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py (+16 -9)

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py (+11 -6)

@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,

ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py (+17 -7)

@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
 
 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py (+16 -6)

@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/phi/phi3.py (+26 -23)

@@ -15,6 +15,7 @@
 
 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""
 
+from functools import partial
 import math
 from typing import Tuple
 
@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]
 
 
-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.
 
   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.
 
   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
      condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.
 
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )
 
 
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config

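Note: Phi-3.5 follows the same injection pattern as Llama above, pre-binding the short-context theta_factors and the log-based scale into the callable. A hedged usage sketch; the head_dim of 96 (so n_elem = 96 and 48 rotated channel pairs) and base of 10,000 are assumptions about the Phi-3.5 Mini config, and ROPE_SHORT_FACTOR / ROPE_SCALE_FACTOR are defined earlier in phi3.py, outside this diff:

    import torch

    from ai_edge_torch.generative.examples.phi import phi3

    config = phi3.get_model_config(kv_cache_max_len=1024)
    # Evaluate the pre-bound builder for an 8-token prefill; extra keyword
    # arguments from a generic caller would be absorbed by **kwargs.
    cos, sin = config.build_rope(torch.arange(8), n_elem=96, base=10_000)
    # cos and sin have one row per position and one column per rotated
    # channel pair, i.e. shape (8, 48) under the assumed head_dim.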