ai-edge-torch-nightly 0.3.0.dev20250107__py3-none-any.whl → 0.3.0.dev20250109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +16 -6
  2. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +16 -6
  3. ai_edge_torch/generative/examples/gemma/gemma2.py +46 -25
  4. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +16 -6
  5. ai_edge_torch/generative/examples/llama/llama.py +29 -25
  6. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +16 -9
  7. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +11 -6
  8. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +17 -7
  9. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +16 -6
  10. ai_edge_torch/generative/examples/phi/phi3.py +26 -23
  11. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +17 -9
  12. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +16 -7
  13. ai_edge_torch/generative/examples/smollm/convert_v2_to_tflite.py +71 -0
  14. ai_edge_torch/generative/examples/smollm/smollm.py +38 -0
  15. ai_edge_torch/generative/examples/smollm/verify.py +18 -2
  16. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
  17. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +16 -8
  18. ai_edge_torch/generative/layers/attention.py +45 -37
  19. ai_edge_torch/generative/layers/lora.py +557 -0
  20. ai_edge_torch/generative/layers/model_config.py +6 -2
  21. ai_edge_torch/generative/layers/rotary_position_embedding.py +34 -28
  22. ai_edge_torch/generative/test/test_lora.py +147 -0
  23. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  24. ai_edge_torch/generative/utilities/converter.py +100 -47
  25. ai_edge_torch/generative/utilities/model_builder.py +23 -14
  26. ai_edge_torch/hlfb/mark_pattern/__init__.py +19 -7
  27. ai_edge_torch/hlfb/mark_pattern/{passes.py → fx_utils.py} +9 -2
  28. ai_edge_torch/hlfb/mark_pattern/pattern.py +9 -8
  29. ai_edge_torch/hlfb/test/test_mark_pattern.py +26 -0
  30. ai_edge_torch/odml_torch/_torch_future.py +13 -0
  31. ai_edge_torch/odml_torch/export.py +6 -2
  32. ai_edge_torch/odml_torch/lowerings/decomp.py +4 -0
  33. ai_edge_torch/version.py +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/METADATA +1 -1
  35. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/RECORD +38 -35
  36. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/LICENSE +0 -0
  37. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/WHEEL +0 -0
  38. {ai_edge_torch_nightly-0.3.0.dev20250107.dist-info → ai_edge_torch_nightly-0.3.0.dev20250109.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma1.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

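Note: this flag change repeats across every converter script in the release, so the two hunks above stand in for the rest: --tflite_path (a full file path) is split into --output_path (a directory) and --output_name_prefix, and a new optional --lora_ranks list is threaded through. A minimal sketch of the resulting call; the keyword names come from the hunks above, while the composition of the final .tflite file name now happens inside converter.convert_to_tflite (ai_edge_torch/generative/utilities/converter.py, +100 -47) and is not visible here, so the values below are illustrative:

from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.utilities import converter
from ai_edge_torch.generative.utilities.model_builder import ExportConfig

pytorch_model = gemma1.build_2b_model(
    '/path/to/gemma-2b', kv_cache_max_len=1280  # illustrative values
)
converter.convert_to_tflite(
    pytorch_model,
    output_path='/tmp/',          # directory; replaces tflite_path
    output_name_prefix='gemma',   # new: prefix for the output file name
    prefill_seq_len=[256, 512],
    quantize=True,
    lora_ranks=[4, 16],           # new: optional; None leaves LoRA out
    export_config=ExportConfig(),
)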
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'gemma2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'gemma2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/gemma/gemma2.py
@@ -15,13 +15,14 @@

 """Example of building a Gemma2 model."""

-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -103,17 +104,12 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    # Gemma2 has same hyper parameters for each layer except for attention
-    # types. Use the first layer.
-    attn_config = config.block_config(0).attn_config
-    self.rope_cache = attn_utils.build_rope_cache(
-        size=config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
+    # Gemma2 has same hyper parameters for each layer except for attention
+    # types. Use the first layer.
+    attn_config = config.block_config(0).attn_config
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -140,29 +136,51 @@ class Gemma2(nn.Module):
           f"Cannot forward sequence of length {seq_len}, max seq length is only"
           f" {self.config.max_seq_len}"
       )
+
+    # token embeddings of shape (b, t, n_embd)
+    input_embeds = self.tok_embedding(tokens)
+    # RoPE parameters are the same for all blocks. Use the first layer.
+    attn_config = self.config.block_config(0).attn_config
+    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
+    rope = rotary_pos_emb.build_rope(
+        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
+    )
+    mask = [
+        self.get_attention_mask(
+            self.config.block_config(i).attn_config.attn_type, input_pos
+        )
+        for i in range(self.config.num_layers)
+    ]
+
+    return self._forward_with_embeds(
+        input_embeds, rope, mask, input_pos, kv_cache, export_config
+    )
+
+  def _forward_with_embeds(
+      self,
+      input_embeds: torch.Tensor,
+      rope: Tuple[torch.Tensor, torch.Tensor],
+      mask: List[torch.Tensor],
+      input_pos: torch.Tensor,
+      kv_cache: kv_utils.KVCache,
+      export_config: Optional[model_builder.ExportConfig] = None,
+  ) -> dict[torch.Tensor, kv_utils.KVCache]:
+    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )

-    cos, sin = self.rope_cache
-    cos = cos.index_select(0, input_pos)
-    sin = sin.index_select(0, input_pos)
-
-    # token embeddings of shape (b, t, n_embd)
-    x = self.tok_embedding(tokens)
-    x = x * (self.config.embedding_dim**0.5)
-
-    updated_kv_entires = []
+    if self.config.embedding_scale is not None:
+      input_embeds = input_embeds * self.config.embedding_scale
+    x = input_embeds
+    updated_kv_entries = []
     for i, block in enumerate(self.transformer_blocks):
-      mask = self.get_attention_mask(
-          block.config.attn_config.attn_type, input_pos
-      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
+      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entires.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
+        updated_kv_entries.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))

     if export_config is not None:
       if (
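This hunk is the core of the Gemma2 refactor: instead of indexing a precomputed, kv_cache_max-sized cos/sin table, forward() now computes RoPE only for the requested positions via rotary_pos_emb.build_rope, and the body is split into _forward_with_embeds so callers can pass embeddings directly. rotary_position_embedding.py itself is only summarized in the file list, so the following is a hedged sketch of what build_rope plausibly computes, inferred from the analogous _build_llama3_rope_cache and _build_phi3_rope builders later in this diff (the real function also receives head_dim, whose role is not visible here):

import torch

def build_rope_sketch(input_pos: torch.Tensor, n_elem: int, base: int = 10_000):
  # Inverse frequencies over the rotary sub-dimension, stepping by 2 as in
  # the _build_llama3_rope_cache / _build_phi3_rope variants below.
  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
  # One angle per (position, frequency) pair; no kv_cache_max-sized table.
  idx_theta = torch.outer(input_pos.float(), theta)
  return torch.cos(idx_theta), torch.sin(idx_theta)

cos, sin = build_rope_sketch(torch.arange(8), n_elem=256)  # e.g. an 8-token prefill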
@@ -228,11 +246,13 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )

   num_layers = 26
+  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=2304,
+      embedding_dim=embedding_dim,
+      embedding_scale=embedding_dim**0.5,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -249,6 +269,7 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
+  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/llama/convert_to_tflite.py
@@ -35,10 +35,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/llama'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'llama',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -55,6 +60,11 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)

 _BUILDER = {
     '1b': llama.build_1b_model,
@@ -66,13 +76,13 @@ def main(_):
   pytorch_model = _BUILDER[_MODEL_SIZE.value](
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'llama_{_MODEL_SIZE.value}_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/llama/llama.py
@@ -15,6 +15,7 @@

 """Example of building Llama 3.2 models."""

+from functools import partial
 import math
 from typing import Tuple

@@ -26,8 +27,8 @@ TENSOR_NAMES = model_builder.TENSOR_NAMES


 def _build_llama3_rope_cache(
-    size: int,
-    dim: int,
+    input_pos: torch.Tensor,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
@@ -36,8 +37,9 @@ def _build_llama3_rope_cache(
     low_freq_factor: float,
     high_freq_factor: float,
     max_seq_len: int,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Llama 3.2 model.
+  """Computes Rotary Positional Embeddings for Llama 3.2 model.

   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Llama 3.2 model. It precomputes Rotary Positional Embedding Sin
@@ -47,13 +49,12 @@ def _build_llama3_rope_cache(
   https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L307

   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
-    base (int, optional): Rope base value.
-    condense_ratio (int, optional): The ratio by which sequence indicies are
-      condensed.
-    dtype (torch.dtype, optional): Output tensor's data type.
-    device (torch.device, optional): Output tensor's data type.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
+    base (int): Rope base value.
+    condense_ratio (int): The ratio by which sequence indicies are condensed.
+    dtype (torch.dtype): Output tensor's data type.
+    device (torch.device): Output tensor's data type.
     factor (float): Factor to scale theta down for tokens in long range in the
       sequence.
     low_freq_factor (float): Factor to determine if tokens are in long range
@@ -66,7 +67,7 @@ def _build_llama3_rope_cache(
   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   low_freq_wavelen = max_seq_len / low_freq_factor
   high_freq_wavelen = max_seq_len / high_freq_factor
   wavelen = 2 * math.pi / theta
@@ -81,7 +82,7 @@ def _build_llama3_rope_cache(
   is_medium = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
   theta = torch.where(is_medium, smoothed_theta, theta)

-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device)
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device)
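The two hunks above only swap the position source (torch.arange(size) becomes the runtime input_pos); the Llama 3.2 frequency-band logic is unchanged. As a self-contained reminder of what that logic selects (the untouched lines between the hunks scale long-wavelength theta values down by factor and smooth the medium band, per the Hugging Face reference in the docstring), here is a small sketch with base and n_elem assumed:

import math
import torch

base, n_elem = 500_000, 64  # assumed values for the sketch
max_seq_len, low_freq_factor, high_freq_factor = 8192, 1.0, 4.0
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
wavelen = 2 * math.pi / theta
low_freq_wavelen = max_seq_len / low_freq_factor    # longer than 8192: theta / factor
high_freq_wavelen = max_seq_len / high_freq_factor  # shorter than 2048: theta unchanged
is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
print(int(is_medium.sum()))  # count of frequencies that get the smoothed blend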
@@ -97,18 +98,6 @@ class Llama(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_llama3_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        factor=32.0,
-        low_freq_factor=1.0,
-        high_freq_factor=4.0,
-        max_seq_len=self.config.max_seq_len,
-    )


 def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -140,15 +129,30 @@ def get_1b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+
+  max_seq_len = 8192
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_llama3_rope_cache,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      factor=32.0,
+      low_freq_factor=1.0,
+      high_freq_factor=4.0,
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=128256,
       num_layers=16,
-      max_seq_len=8192,
+      max_seq_len=max_seq_len,
       embedding_dim=2048,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=block_config,
       final_norm_config=norm_config,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config

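With this hunk, the Llama-specific RoPE builder is registered on the config instead of being materialized in __init__: functools.partial pre-binds the scaling hyperparameters, and the shared decoder (utilities/model_builder.py, changed in this release but not shown) is expected to invoke config.build_rope with the runtime positions. A hedged sketch of that call; the call convention and the n_elem/base values below are assumptions:

import torch
from ai_edge_torch.generative.examples.llama import llama

config = llama.get_1b_model_config()
# Only input_pos, n_elem, and base remain unbound after the partial;
# **kwargs absorbs anything else (e.g. head_dim) the shared call site passes.
cos, sin = config.build_rope(
    input_pos=torch.arange(0, 16),  # positions of the current prefill chunk
    n_elem=64,                      # rotary_percentage * head_dim; assumed value
    base=500_000,                   # attn_config.rotary_base; assumed value
)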
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/openelm'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'openelm',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,22 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = openelm.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = (
-      f'openelm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  )
-
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py
@@ -40,10 +40,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma2-3b-224'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'paligemma',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LEN = flags.DEFINE_integer(
     'prefill_seq_len',
@@ -73,11 +78,11 @@ def main(_):
       version=int(_VERSION.value),
       kv_cache_max_len=_KV_CACHE_MAX_LEN.value,
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'paligemma{_VERSION.value}_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=f'{_OUTPUT_NAME_PREFIX.value}_{_VERSION.value}',
       prefill_seq_len=_PREFILL_SEQ_LEN.value,
       pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
       quantize=_QUANTIZE.value,
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py
@@ -26,13 +26,18 @@ from ai_edge_torch.generative.utilities.model_builder import ExportConfig

 _CHECKPOINT_PATH = flags.DEFINE_string(
     'checkpoint_path',
-    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi3'),
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma-2b'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi3',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi3.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi3_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/phi/convert_to_tflite.py
@@ -29,10 +29,15 @@ _CHECKPOINT_PATH = flags.DEFINE_string(
     os.path.join(pathlib.Path.home(), 'Downloads/llm_data/phi2'),
     'The path to the model checkpoint, or directory holding the checkpoint.',
 )
-_TFLITE_PATH = flags.DEFINE_string(
-    'tflite_path',
+_OUTPUT_PATH = flags.DEFINE_string(
+    'output_path',
     '/tmp/',
-    'The tflite file path to export.',
+    'The path to export the tflite model.',
+)
+_OUTPUT_NAME_PREFIX = flags.DEFINE_string(
+    'output_name_prefix',
+    'phi2',
+    'The prefix of the output tflite model name.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
     'prefill_seq_lens',
@@ -49,19 +54,24 @@ _QUANTIZE = flags.DEFINE_bool(
     True,
     'Whether the model should be quantized.',
 )
+_LORA_RANKS = flags.DEFINE_multi_integer(
+    'lora_ranks',
+    None,
+    'If set, the model will be converted with the provided list of LoRA ranks.',
+)


 def main(_):
   pytorch_model = phi2.build_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
-  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
-  output_filename = f'phi2_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
   converter.convert_to_tflite(
       pytorch_model,
-      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      output_path=_OUTPUT_PATH.value,
+      output_name_prefix=_OUTPUT_NAME_PREFIX.value,
       prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
+      lora_ranks=_LORA_RANKS.value,
       export_config=ExportConfig(),
   )

ai_edge_torch/generative/examples/phi/phi3.py
@@ -15,6 +15,7 @@

 """Example of building a Phi-3.5 model up to 4K tokens, not to 128K tokens."""

+from functools import partial
 import math
 from typing import Tuple

@@ -93,40 +94,41 @@ ROPE_SHORT_FACTOR = [
 ]


-def _build_rope_cache(
-    size: int,
-    dim: int,
+def _build_phi3_rope(
+    input_pos: int,
+    n_elem: int,
     base: int,
     condense_ratio: int,
     dtype: torch.dtype,
     device: torch.device,
     theta_factors: torch.Tensor,
     scale: float,
+    **kwargs,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Precomputes Rotary Positional Embeddings for Phi-3.5 model.
+  """Computes Rotary Positional Embeddings for Phi-3.5 model.

   It's a modified version of attn_utils.build_rope_cache with additional
   arguments for Phi-3.5 model. It precompute Rotary Positional Embedding Sin and
   Cos values with scaling factors for quick lookup during the inference.

   Args:
-    size (int): The size of the built cache.
-    dim (int): Each sequence's dimmension.
+    input_pos (torch.Tensor): the given input sequence positions
+    n_elem (int): Each sequence's dimmension.
     base (int, optional): Rope base value.
     condense_ratio (int, optional): The ratio by which sequence indicies are
       condensed.
     dtype (torch.dtype, optional): Output tensor's data type.
     device (torch.device, optional): Output tensor's data type.
-    theta_factors (torch.Tensor, optional): A tensor of shape (dim,) used to
-      scale the theta values.
+    theta_factors (torch.Tensor, optional): A tensor of shape (n_elem,) used
+      to scale the theta values.
     scale (float, optional): A float used to scale the rope values.

   Returns:
     Tuple[torch.Tensor, torch.Tensor]: Rope's Cosine and Sine waves.
   """
-  theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+  theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   theta = theta / theta_factors
-  seq_idx = torch.arange(size) / condense_ratio
+  seq_idx = input_pos / condense_ratio
   idx_theta = torch.outer(seq_idx, theta)
   cos = torch.cos(idx_theta).to(dtype=dtype, device=device) * scale
   sin = torch.sin(idx_theta).to(dtype=dtype, device=device) * scale
@@ -139,18 +141,6 @@ class Phi3_5Mini(model_builder.DecoderOnlyModel):
   def __init__(self, config: cfg.ModelConfig):
     super().__init__(config)
     attn_config = self.config.block_config(0).attn_config
-    self.rope_cache = _build_rope_cache(
-        size=self.config.kv_cache_max,
-        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
-        base=attn_config.rotary_base,
-        condense_ratio=1,
-        dtype=torch.float32,
-        device=torch.device("cpu"),
-        theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
-        scale=math.sqrt(
-            1 + math.log(ROPE_SCALE_FACTOR) / math.log(config.max_seq_len)
-        ),
-    )


 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
@@ -183,16 +173,29 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
       pre_attention_norm_config=norm_config,
       post_attention_norm_config=norm_config,
   )
+  max_seq_len = 4096
+  # Create the RoPE callable
+  build_rope = partial(
+      _build_phi3_rope,
+      condense_ratio=1,
+      dtype=torch.float32,
+      device=torch.device("cpu"),
+      theta_factors=torch.tensor(ROPE_SHORT_FACTOR),
+      scale=math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len)),
+      max_seq_len=max_seq_len,
+  )
+
   config = cfg.ModelConfig(
       vocab_size=32064,
       num_layers=32,
-      max_seq_len=4096,
+      max_seq_len=max_seq_len,
       kv_cache_max_len=kv_cache_max_len,
       embedding_dim=3072,
       block_configs=block_config,
       final_norm_config=norm_config,
       lm_head_share_weight_with_embedding=False,
       enable_hlfb=True,
+      build_rope=build_rope,
   )
   return config
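phi3.py mirrors the llama.py change: the long-rope scale expression moves out of __init__ into the partial, now written against the local max_seq_len = 4096 rather than config.max_seq_len (the same value). A quick arithmetic check of that expression; ROPE_SCALE_FACTOR is defined above these hunks, so the value used here is an assumption:

import math

ROPE_SCALE_FACTOR = 32  # assumed: 128K extended / 4K original context
max_seq_len = 4096
scale = math.sqrt(1 + math.log(ROPE_SCALE_FACTOR) / math.log(max_seq_len))
# log(32)/log(4096) = 5/12, so scale = sqrt(17/12) ~= 1.19; every cos/sin
# value produced by _build_phi3_rope is multiplied by this constant.
print(round(scale, 4))  # 1.1902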