ai-edge-torch-nightly 0.3.0.dev20241209__py3-none-any.whl → 0.3.0.dev20241210__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (35) hide show
  1. ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py +8 -3
  2. ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py +2 -0
  3. ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py +2 -0
  4. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +2 -0
  5. ai_edge_torch/generative/examples/gemma/gemma1.py +8 -3
  6. ai_edge_torch/generative/examples/gemma/gemma2.py +15 -8
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +2 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +11 -17
  9. ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +2 -0
  10. ai_edge_torch/generative/examples/openelm/openelm.py +8 -3
  11. ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +2 -0
  12. ai_edge_torch/generative/examples/paligemma/decoder.py +7 -9
  13. ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py +2 -0
  14. ai_edge_torch/generative/examples/phi/convert_to_tflite.py +2 -0
  15. ai_edge_torch/generative/examples/phi/phi2.py +8 -3
  16. ai_edge_torch/generative/examples/phi/phi3.py +7 -9
  17. ai_edge_torch/generative/examples/qwen/convert_to_tflite.py +2 -0
  18. ai_edge_torch/generative/examples/qwen/qwen.py +12 -9
  19. ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +3 -0
  20. ai_edge_torch/generative/examples/smollm/smollm.py +8 -3
  21. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +12 -2
  22. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +2 -0
  23. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +8 -3
  24. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +6 -6
  26. ai_edge_torch/generative/utilities/converter.py +25 -4
  27. ai_edge_torch/generative/utilities/model_builder.py +24 -4
  28. ai_edge_torch/generative/utilities/transformers_verifier.py +3 -3
  29. ai_edge_torch/generative/utilities/verifier.py +16 -2
  30. ai_edge_torch/version.py +1 -1
  31. {ai_edge_torch_nightly-0.3.0.dev20241209.dist-info → ai_edge_torch_nightly-0.3.0.dev20241210.dist-info}/METADATA +1 -1
  32. {ai_edge_torch_nightly-0.3.0.dev20241209.dist-info → ai_edge_torch_nightly-0.3.0.dev20241210.dist-info}/RECORD +35 -35
  33. {ai_edge_torch_nightly-0.3.0.dev20241209.dist-info → ai_edge_torch_nightly-0.3.0.dev20241210.dist-info}/LICENSE +0 -0
  34. {ai_edge_torch_nightly-0.3.0.dev20241209.dist-info → ai_edge_torch_nightly-0.3.0.dev20241210.dist-info}/WHEEL +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20241209.dist-info → ai_edge_torch_nightly-0.3.0.dev20241210.dist-info}/top_level.txt +0 -0
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
22
23
 
23
24
 
25
+ class AmdLlama(model_builder.DecoderOnlyModel):
26
+ """An AMD-Llama model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for an AMD-Llama-135m model.
26
32
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
72
78
  return config
73
79
 
74
80
 
75
- def build_model(
76
- checkpoint_path: str, **kwargs
77
- ) -> model_builder.DecoderOnlyModel:
81
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
78
82
  return model_builder.build_decoder_only_model(
79
83
  checkpoint_path=checkpoint_path,
80
84
  config=get_model_config(**kwargs),
81
85
  tensor_names=TENSOR_NAMES,
86
+ model_class=AmdLlama
82
87
  )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LEN.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.gemma import gemma1
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.gemma import gemma2
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
33
34
  )
34
35
 
35
36
 
37
+ class Gemma1(model_builder.DecoderOnlyModel):
38
+ """A Gemma1 model built from the Edge Generative API layers."""
39
+ pass
40
+
41
+
36
42
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
37
43
  """Returns the model config for a Gemma 2B model.
38
44
 
@@ -91,11 +97,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
91
97
  return config
92
98
 
93
99
 
94
- def build_2b_model(
95
- checkpoint_path: str, **kwargs
96
- ) -> model_builder.DecoderOnlyModel:
100
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
97
101
  return model_builder.build_decoder_only_model(
98
102
  checkpoint_path=checkpoint_path,
99
103
  config=get_model_config_2b(**kwargs),
100
104
  tensor_names=TENSOR_NAMES,
105
+ model_class=Gemma1,
101
106
  )
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.layers import builder
22
22
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
23
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
+ from ai_edge_torch.generative.utilities import model_builder
25
26
  import ai_edge_torch.generative.utilities.loader as loading_utils
26
27
  import torch
27
28
  from torch import nn
@@ -132,6 +133,7 @@ class Gemma2(nn.Module):
132
133
  tokens: torch.Tensor,
133
134
  input_pos: torch.Tensor,
134
135
  kv_cache: kv_utils.KVCache,
136
+ export_config: Optional[model_builder.ExportConfig] = None,
135
137
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
136
138
  _, seq_len = tokens.size()
137
139
  assert self.config.max_seq_len >= seq_len, (
@@ -162,6 +164,13 @@ class Gemma2(nn.Module):
162
164
  updated_kv_entires.append(kv_entry)
163
165
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
164
166
 
167
+ if export_config is not None:
168
+ if (
169
+ torch.numel(input_pos) > 1
170
+ and not export_config.output_logits_on_prefill
171
+ ):
172
+ return {"kv_cache": updated_kv_cache}
173
+
165
174
  x = self.final_norm(x)
166
175
  res = self.lm_head(x) # (b, t, vocab_size)
167
176
  if self.config.final_logit_softcap is not None:
@@ -250,11 +259,9 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
250
259
 
251
260
 
252
261
  def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
253
- config = get_model_config_2b(**kwargs)
254
- model = Gemma2(config)
255
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
256
- # Since embedding and lm-head use the same weight, we need to set strict
257
- # to False.
258
- loader.load(model, strict=False)
259
- model.eval()
260
- return model
262
+ return model_builder.build_decoder_only_model(
263
+ checkpoint_path=checkpoint_path,
264
+ config=get_model_config_2b(**kwargs),
265
+ tensor_names=TENSOR_NAMES,
266
+ model_class=Gemma2,
267
+ )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.llama import llama
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _MODEL_SIZE = flags.DEFINE_enum(
27
28
  'model_size',
@@ -72,6 +73,7 @@ def main(_):
72
73
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
73
74
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
74
75
  quantize=_QUANTIZE.value,
76
+ export_config=ExportConfig(),
75
77
  )
76
78
 
77
79
 
@@ -20,7 +20,6 @@ from typing import Tuple
20
20
 
21
21
  import ai_edge_torch.generative.layers.model_config as cfg
22
22
  from ai_edge_torch.generative.utilities import model_builder
23
- import ai_edge_torch.generative.utilities.loader as loading_utils
24
23
  import torch
25
24
 
26
25
  TENSOR_NAMES = model_builder.TENSOR_NAMES
@@ -177,23 +176,18 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
177
176
 
178
177
  def _build_model(
179
178
  checkpoint_path: str, config: cfg.ModelConfig
180
- ) -> model_builder.DecoderOnlyModel:
181
- model = Llama(config)
182
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
183
- # Since embedding and lm-head use the same weight, we need to set strict
184
- # to False.
185
- loader.load(model, strict=False)
186
- model.eval()
187
- return model
188
-
189
-
190
- def build_1b_model(
191
- checkpoint_path: str, **kwargs
192
- ) -> model_builder.DecoderOnlyModel:
179
+ ) -> torch.nn.Module:
180
+ return model_builder.build_decoder_only_model(
181
+ checkpoint_path=checkpoint_path,
182
+ config=config,
183
+ tensor_names=TENSOR_NAMES,
184
+ model_class=Llama,
185
+ )
186
+
187
+
188
+ def build_1b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
193
189
  return _build_model(checkpoint_path, get_1b_model_config(**kwargs))
194
190
 
195
191
 
196
- def build_3b_model(
197
- checkpoint_path: str, **kwargs
198
- ) -> model_builder.DecoderOnlyModel:
192
+ def build_3b_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
199
193
  return _build_model(checkpoint_path, get_3b_model_config(**kwargs))
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.openelm import openelm
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -64,6 +65,7 @@ def main(_):
64
65
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
65
66
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
66
67
  quantize=_QUANTIZE.value,
68
+ export_config=ExportConfig(),
67
69
  )
68
70
 
69
71
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="transformer.layers.{}.ffn.proj_1",
@@ -34,6 +35,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
34
35
  )
35
36
 
36
37
 
38
+ class OpenELM(model_builder.DecoderOnlyModel):
39
+ """An OpenELM model built from the Edge Generative API layers."""
40
+ pass
41
+
42
+
37
43
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
38
44
  """Returns the model config for an OpenELM model.
39
45
 
@@ -112,11 +118,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
112
118
  return config
113
119
 
114
120
 
115
- def build_model(
116
- checkpoint_path: str, **kwargs
117
- ) -> model_builder.DecoderOnlyModel:
121
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
118
122
  return model_builder.build_decoder_only_model(
119
123
  checkpoint_path=checkpoint_path,
120
124
  config=get_model_config(**kwargs),
121
125
  tensor_names=TENSOR_NAMES,
126
+ model_class=OpenELM,
122
127
  )
@@ -26,6 +26,7 @@ from absl import app
26
26
  from absl import flags
27
27
  from ai_edge_torch.generative.examples.paligemma import paligemma
28
28
  from ai_edge_torch.generative.utilities import converter
29
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
29
30
  import torch
30
31
 
31
32
  _CHECKPOINT_PATH = flags.DEFINE_string(
@@ -73,6 +74,7 @@ def main(_):
73
74
  pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
74
75
  quantize=_QUANTIZE.value,
75
76
  config=pytorch_model.config.decoder_config,
77
+ export_config=ExportConfig(),
76
78
  )
77
79
 
78
80
 
@@ -130,12 +130,10 @@ def get_fake_decoder_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
130
130
  return config
131
131
 
132
132
 
133
- def build_decoder(
134
- checkpoint_path: str, **kwargs
135
- ) -> model_builder.DecoderOnlyModel:
136
- decoder = Decoder(get_decoder_config(**kwargs))
137
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
138
- # Loose the strictness because only decoder is being loaded.
139
- loader.load(decoder, strict=False)
140
- decoder.eval()
141
- return decoder
133
+ def build_decoder(checkpoint_path: str, **kwargs) -> torch.nn.Module:
134
+ return model_builder.build_decoder_only_model(
135
+ checkpoint_path=checkpoint_path,
136
+ config=get_decoder_config(**kwargs),
137
+ tensor_names=TENSOR_NAMES,
138
+ model_class=Decoder,
139
+ )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.phi import phi3
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.phi import phi2
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -61,6 +62,7 @@ def main(_):
61
62
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
63
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
64
  quantize=_QUANTIZE.value,
65
+ export_config=ExportConfig(),
64
66
  )
65
67
 
66
68
 
@@ -18,6 +18,7 @@
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
20
  import ai_edge_torch.generative.utilities.loader as loading_utils
21
+ from torch import nn
21
22
 
22
23
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
23
24
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -33,6 +34,11 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
33
34
  )
34
35
 
35
36
 
37
+ class Phi2(model_builder.DecoderOnlyModel):
38
+ """A Phi-2 model built from the Edge Generative API layers."""
39
+ pass
40
+
41
+
36
42
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
37
43
  """Returns the model config for a Phi-2 model.
38
44
 
@@ -92,11 +98,10 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
92
98
  return config
93
99
 
94
100
 
95
- def build_model(
96
- checkpoint_path: str, **kwargs
97
- ) -> model_builder.DecoderOnlyModel:
101
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
98
102
  return model_builder.build_decoder_only_model(
99
103
  checkpoint_path=checkpoint_path,
100
104
  config=get_model_config(**kwargs),
101
105
  tensor_names=TENSOR_NAMES,
106
+ model_class=Phi2,
102
107
  )
@@ -207,13 +207,11 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
207
207
  return config
208
208
 
209
209
 
210
- def build_model(
211
- checkpoint_path: str, **kwargs
212
- ) -> model_builder.DecoderOnlyModel:
210
+ def build_model(checkpoint_path: str, **kwargs) -> torch.nn.Module:
213
211
  """Instantiates the model instance and load checkpoint if provided."""
214
- config = get_model_config(**kwargs)
215
- model = Phi3_5Mini(config)
216
- loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
217
- loader.load(model)
218
- model.eval()
219
- return model
212
+ return model_builder.build_decoder_only_model(
213
+ checkpoint_path=checkpoint_path,
214
+ config=get_model_config(**kwargs),
215
+ tensor_names=TENSOR_NAMES,
216
+ model_class=Phi3_5Mini,
217
+ )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.qwen import qwen
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _MODEL_SIZE = flags.DEFINE_enum(
27
28
  'model_size',
@@ -76,6 +77,7 @@ def main(_):
76
77
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
77
78
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
78
79
  quantize=_QUANTIZE.value,
80
+ export_config=ExportConfig(),
79
81
  )
80
82
 
81
83
 
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES
22
23
 
23
24
 
25
+ class Qwen(model_builder.DecoderOnlyModel):
26
+ """A Qwen model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_3b_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for a Qwen 2.5 3B model.
26
32
 
@@ -101,31 +107,28 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
101
107
  return config
102
108
 
103
109
 
104
- def build_3b_model(
105
- checkpoint_path: str, **kwargs
106
- ) -> model_builder.DecoderOnlyModel:
110
+ def build_3b_model(checkpoint_path: str, **kwargs) -> nn.Module:
107
111
  return model_builder.build_decoder_only_model(
108
112
  checkpoint_path=checkpoint_path,
109
113
  config=get_3b_model_config(**kwargs),
110
114
  tensor_names=TENSOR_NAMES,
115
+ model_class=Qwen,
111
116
  )
112
117
 
113
118
 
114
- def build_1_5b_model(
115
- checkpoint_path: str, **kwargs
116
- ) -> model_builder.DecoderOnlyModel:
119
+ def build_1_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
117
120
  return model_builder.build_decoder_only_model(
118
121
  checkpoint_path=checkpoint_path,
119
122
  config=get_1_5b_model_config(**kwargs),
120
123
  tensor_names=TENSOR_NAMES,
124
+ model_class=Qwen,
121
125
  )
122
126
 
123
127
 
124
- def build_0_5b_model(
125
- checkpoint_path: str, **kwargs
126
- ) -> model_builder.DecoderOnlyModel:
128
+ def build_0_5b_model(checkpoint_path: str, **kwargs) -> nn.Module:
127
129
  return model_builder.build_decoder_only_model(
128
130
  checkpoint_path=checkpoint_path,
129
131
  config=get_0_5b_model_config(**kwargs),
130
132
  tensor_names=TENSOR_NAMES,
133
+ model_class=Qwen,
131
134
  )
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.smollm import smollm
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -54,6 +55,7 @@ def main(_):
54
55
  pytorch_model = smollm.build_model(
55
56
  _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
56
57
  )
58
+
57
59
  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
58
60
  output_filename = f'smollm_{quant_suffix}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
59
61
  converter.convert_to_tflite(
@@ -61,6 +63,7 @@ def main(_):
61
63
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
62
64
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
63
65
  quantize=_QUANTIZE.value,
66
+ export_config=ExportConfig(),
64
67
  )
65
68
 
66
69
 
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES
22
23
 
23
24
 
25
+ class SmolLM(model_builder.DecoderOnlyModel):
26
+ """A SmolLM model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for a SmolLM 135M model.
26
32
 
@@ -72,11 +78,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
72
78
  return config
73
79
 
74
80
 
75
- def build_model(
76
- checkpoint_path: str, **kwargs
77
- ) -> model_builder.DecoderOnlyModel:
81
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
78
82
  return model_builder.build_decoder_only_model(
79
83
  checkpoint_path=checkpoint_path,
80
84
  config=get_model_config(**kwargs),
81
85
  tensor_names=TENSOR_NAMES,
86
+ model_class=SmolLM,
82
87
  )
@@ -15,13 +15,14 @@
15
15
 
16
16
  """A toy example which has basic transformer block (w/ externalized KV-Cache)."""
17
17
 
18
- from typing import Tuple
18
+ from typing import Optional, Tuple
19
19
 
20
20
  from ai_edge_torch.generative.layers import attention
21
21
  from ai_edge_torch.generative.layers import builder
22
22
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
23
23
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
  import torch
26
27
  from torch import nn
27
28
 
@@ -62,6 +63,7 @@ class ToyModelWithKVCache(torch.nn.Module):
62
63
  tokens: torch.Tensor,
63
64
  input_pos: torch.Tensor,
64
65
  kv_cache: kv_utils.KVCache,
66
+ export_config: Optional[ExportConfig] = None,
65
67
  ) -> Tuple[torch.Tensor, kv_utils.KVCache]:
66
68
  x = self.tok_embedding(tokens)
67
69
  cos, sin = self.rope_cache
@@ -77,8 +79,16 @@ class ToyModelWithKVCache(torch.nn.Module):
77
79
  if kv_entry:
78
80
  updated_kv_entires.append(kv_entry)
79
81
 
80
- x = self.final_norm(x)
81
82
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
83
+
84
+ if export_config is not None:
85
+ if (
86
+ torch.numel(input_pos) > 1
87
+ and not export_config.output_logits_on_prefill
88
+ ):
89
+ return {'kv_cache': updated_kv_cache}
90
+
91
+ x = self.final_norm(x)
82
92
  return {'logits': self.lm_head(x), 'kv_cache': updated_kv_cache}
83
93
 
84
94
 
@@ -22,6 +22,7 @@ from absl import app
22
22
  from absl import flags
23
23
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
24
24
  from ai_edge_torch.generative.utilities import converter
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
25
26
 
26
27
  _CHECKPOINT_PATH = flags.DEFINE_string(
27
28
  'checkpoint_path',
@@ -63,6 +64,7 @@ def main(_):
63
64
  tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
64
65
  prefill_seq_len=_PREFILL_SEQ_LENS.value,
65
66
  quantize=_QUANTIZE.value,
67
+ export_config=ExportConfig(),
66
68
  )
67
69
 
68
70
 
@@ -17,10 +17,16 @@
17
17
 
18
18
  import ai_edge_torch.generative.layers.model_config as cfg
19
19
  from ai_edge_torch.generative.utilities import model_builder
20
+ from torch import nn
20
21
 
21
22
  TENSOR_NAMES = model_builder.TENSOR_NAMES_WITH_SEPARATE_LM_HEAD
22
23
 
23
24
 
25
+ class TinyLlama(model_builder.DecoderOnlyModel):
26
+ """A TinyLlama model built from the Edge Generative API layers."""
27
+ pass
28
+
29
+
24
30
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
25
31
  """Returns the model config for a TinyLlama model.
26
32
 
@@ -73,11 +79,10 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
73
79
  return config
74
80
 
75
81
 
76
- def build_model(
77
- checkpoint_path: str, **kwargs
78
- ) -> model_builder.DecoderOnlyModel:
82
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
79
83
  return model_builder.build_decoder_only_model(
80
84
  checkpoint_path=checkpoint_path,
81
85
  config=get_model_config(**kwargs),
82
86
  tensor_names=TENSOR_NAMES,
87
+ model_class=TinyLlama,
83
88
  )
@@ -185,7 +185,7 @@ class TestModelConversion(googletest.TestCase):
185
185
  )
186
186
  def test_tiny_llama_multisig(self):
187
187
  config = tiny_llama.get_fake_model_config()
188
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
188
+ pytorch_model = tiny_llama.TinyLlama(config).eval()
189
189
  self._test_multisig_model(config, pytorch_model, atol=1e-5, rtol=1e-5)
190
190
 
191
191
 
@@ -93,7 +93,7 @@ class TestModelConversion(googletest.TestCase):
93
93
  )
94
94
  def test_gemma1(self):
95
95
  config = gemma1.get_fake_model_config()
96
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
96
+ pytorch_model = gemma1.Gemma1(config).eval()
97
97
  self._test_model(
98
98
  config, pytorch_model, "serving_default", atol=1e-2, rtol=1e-5
99
99
  )
@@ -122,7 +122,7 @@ class TestModelConversion(googletest.TestCase):
122
122
  )
123
123
  def test_phi2(self):
124
124
  config = phi2.get_fake_model_config()
125
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
125
+ pytorch_model = phi2.Phi2(config).eval()
126
126
  self._test_model(
127
127
  config, pytorch_model, "serving_default", atol=1e-3, rtol=1e-3
128
128
  )
@@ -142,7 +142,7 @@ class TestModelConversion(googletest.TestCase):
142
142
  )
143
143
  def test_smollm(self):
144
144
  config = smollm.get_fake_model_config()
145
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
145
+ pytorch_model = smollm.SmolLM(config).eval()
146
146
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
147
147
 
148
148
  @googletest.skipIf(
@@ -151,7 +151,7 @@ class TestModelConversion(googletest.TestCase):
151
151
  )
152
152
  def test_openelm(self):
153
153
  config = openelm.get_fake_model_config()
154
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
154
+ pytorch_model = openelm.OpenElm(config).eval()
155
155
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
156
156
 
157
157
  @googletest.skipIf(
@@ -160,7 +160,7 @@ class TestModelConversion(googletest.TestCase):
160
160
  )
161
161
  def test_qwen(self):
162
162
  config = qwen.get_fake_model_config()
163
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
163
+ pytorch_model = qwen.Qwen(config).eval()
164
164
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
165
165
 
166
166
  @googletest.skipIf(
@@ -169,7 +169,7 @@ class TestModelConversion(googletest.TestCase):
169
169
  )
170
170
  def test_amd_llama_135m(self):
171
171
  config = amd_llama_135m.get_fake_model_config()
172
- pytorch_model = model_builder.DecoderOnlyModel(config).eval()
172
+ pytorch_model = amd_llama_135m.AmdLlama(config).eval()
173
173
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
174
174
 
175
175
  @googletest.skipIf(
@@ -15,13 +15,28 @@
15
15
 
16
16
  """Common utility functions for model conversion."""
17
17
 
18
- from typing import Union
18
+ from functools import partial
19
+ from typing import Any, Union
19
20
 
20
21
  from ai_edge_torch._convert import converter as converter_utils
21
22
  import ai_edge_torch.generative.layers.kv_cache as kv_utils
22
23
  import ai_edge_torch.generative.layers.model_config as cfg
23
24
  from ai_edge_torch.generative.quantize import quant_recipes
25
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
24
26
  import torch
27
+ import torch.nn as nn
28
+
29
+
30
+ class ExportableModule(torch.nn.Module):
31
+
32
+ def __init__(self, module, **extra_kwargs):
33
+ super().__init__()
34
+ self.module = module
35
+ self.extra_kwargs = extra_kwargs
36
+
37
+ def forward(self, *export_args, **export_kwargs):
38
+ full_kwargs = {**export_kwargs, **self.extra_kwargs}
39
+ return self.module(*export_args, **full_kwargs)
25
40
 
26
41
 
27
42
  def convert_to_tflite(
@@ -31,6 +46,7 @@ def convert_to_tflite(
31
46
  pixel_values_size: torch.Size = None,
32
47
  quantize: bool = True,
33
48
  config: cfg.ModelConfig = None,
49
+ export_config: ExportConfig = None,
34
50
  ):
35
51
  """Converts a nn.Module model to multi-signature tflite model.
36
52
 
@@ -97,6 +113,11 @@ def convert_to_tflite(
97
113
  )
98
114
 
99
115
  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
116
+
117
+ # For export, we create a module that captures any non-exportable
118
+ # arguments, e.g. the generation config object.
119
+ mod = ExportableModule(pytorch_model, export_config=export_config)
120
+
100
121
  converter = converter_utils.Converter()
101
122
  for i in range(len(prefill_seq_lens)):
102
123
  prefill_seq_len = prefill_seq_lens[i]
@@ -108,7 +129,7 @@ def convert_to_tflite(
108
129
  prefill_signature_name = f'prefill_{prefill_seq_len}'
109
130
  converter.add_signature(
110
131
  prefill_signature_name,
111
- pytorch_model,
132
+ mod,
112
133
  sample_kwargs={
113
134
  'tokens': prefill_tokens,
114
135
  'input_pos': prefill_input_pos,
@@ -118,7 +139,7 @@ def convert_to_tflite(
118
139
  if prefill_pixel_values is not None:
119
140
  converter.add_signature(
120
141
  prefill_signature_name + '_pixel',
121
- pytorch_model,
142
+ mod,
122
143
  sample_kwargs={
123
144
  'tokens': prefill_tokens,
124
145
  'input_pos': prefill_input_pos,
@@ -129,7 +150,7 @@ def convert_to_tflite(
129
150
 
130
151
  converter.add_signature(
131
152
  'decode',
132
- pytorch_model,
153
+ mod,
133
154
  sample_kwargs={
134
155
  'tokens': decode_token,
135
156
  'input_pos': decode_input_pos,
@@ -16,7 +16,8 @@
16
16
  """Utilities to be used for re-authoring transformer models."""
17
17
 
18
18
  import copy
19
- from typing import Tuple
19
+ from dataclasses import dataclass
20
+ from typing import Optional, Tuple
20
21
 
21
22
  from ai_edge_torch.generative.layers import attention
22
23
  from ai_edge_torch.generative.layers import builder
@@ -45,6 +46,15 @@ TENSOR_NAMES_WITH_SEPARATE_LM_HEAD = copy.copy(TENSOR_NAMES)
45
46
  TENSOR_NAMES_WITH_SEPARATE_LM_HEAD.lm_head = "lm_head"
46
47
 
47
48
 
49
+ @dataclass
50
+ class ExportConfig:
51
+ """Model generating configuration settings."""
52
+
53
+ # On prefill signatures, should the model produce logit output?
54
+ # When False, only decode signatures will produce output.
55
+ output_logits_on_prefill: bool = False
56
+
57
+
48
58
  class DecoderOnlyModel(nn.Module):
49
59
  """A simple decoder-only transformer model built from the Edge Generative API.
50
60
 
@@ -93,6 +103,7 @@ class DecoderOnlyModel(nn.Module):
93
103
  tokens: torch.Tensor,
94
104
  input_pos: torch.Tensor,
95
105
  kv_cache: kv_utils.KVCache,
106
+ export_config: Optional[ExportConfig] = None,
96
107
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
97
108
  _, seq_len = tokens.size()
98
109
  assert self.config.max_seq_len >= seq_len, (
@@ -108,7 +119,7 @@ class DecoderOnlyModel(nn.Module):
108
119
  mask = mask[:, :, :, : self.config.kv_cache_max]
109
120
 
110
121
  return self.forward_with_embeds(
111
- input_embeds, rope, mask, input_pos, kv_cache
122
+ input_embeds, rope, mask, input_pos, kv_cache, export_config
112
123
  )
113
124
 
114
125
  def forward_with_embeds(
@@ -118,6 +129,7 @@ class DecoderOnlyModel(nn.Module):
118
129
  mask: torch.Tensor,
119
130
  input_pos: torch.Tensor,
120
131
  kv_cache: kv_utils.KVCache,
132
+ export_config: Optional[ExportConfig] = None,
121
133
  ) -> dict[torch.Tensor, kv_utils.KVCache]:
122
134
  """Forwards the model with input embeddings."""
123
135
  assert len(self.transformer_blocks) == len(kv_cache.caches), (
@@ -137,6 +149,13 @@ class DecoderOnlyModel(nn.Module):
137
149
  updated_kv_entires.append(kv_entry)
138
150
  updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
139
151
 
152
+ if export_config is not None:
153
+ if (
154
+ torch.numel(input_pos) > 1
155
+ and not export_config.output_logits_on_prefill
156
+ ):
157
+ return {"kv_cache": updated_kv_cache}
158
+
140
159
  x = self.final_norm(x)
141
160
  logits = self.lm_head(x) # (b, t, vocab_size)
142
161
  return {"logits": logits, "kv_cache": updated_kv_cache}
@@ -146,8 +165,9 @@ def build_decoder_only_model(
146
165
  checkpoint_path: str,
147
166
  config: cfg.ModelConfig,
148
167
  tensor_names: loading_utils.ModelLoader.TensorNames,
149
- ) -> DecoderOnlyModel:
150
- transformer = DecoderOnlyModel(config)
168
+ model_class: type[nn.Module] = DecoderOnlyModel,
169
+ ) -> nn.Module:
170
+ transformer = model_class(config)
151
171
  loader = loading_utils.ModelLoader(checkpoint_path, tensor_names)
152
172
  loader.load(
153
173
  transformer, strict=not config.lm_head_share_weight_with_embedding
@@ -29,7 +29,7 @@ class TransformersModelWrapper(verifier.ModelWrapper):
29
29
  an object with `logits` field.
30
30
 
31
31
  Transformers models get `max_new_tokens` settings for generate() via
32
- GenerationConfig.
32
+ ExportConfig.
33
33
  """
34
34
 
35
35
  def forward(self, tokens: torch.Tensor) -> torch.Tensor:
@@ -38,5 +38,5 @@ class TransformersModelWrapper(verifier.ModelWrapper):
38
38
  def generate(
39
39
  self, inputs: torch.Tensor, max_new_tokens: int
40
40
  ) -> torch.IntTensor:
41
- gen_config = transformers.GenerationConfig(max_new_tokens=max_new_tokens)
42
- return self.model.generate(inputs=inputs, generation_config=gen_config)
41
+ export_config = transformers.ExportConfig(max_new_tokens=max_new_tokens)
42
+ return self.model.generate(inputs=inputs, generation_config=export_config)
@@ -19,6 +19,7 @@ import logging
19
19
  from typing import List
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
+ from ai_edge_torch.generative.utilities.model_builder import ExportConfig
22
23
  import torch
23
24
 
24
25
 
@@ -40,6 +41,7 @@ class ModelWrapper(torch.nn.Module):
40
41
  """
41
42
  super().__init__()
42
43
  self.model = model
44
+ self.export_config = ExportConfig(output_logits_on_prefill=True)
43
45
 
44
46
  def forward(
45
47
  self, tokens: torch.Tensor, pixel_values: torch.Tensor = None
@@ -103,13 +105,25 @@ class ReauthoredModelWrapper(ModelWrapper):
103
105
  Returns:
104
106
  The output logits and the updated KV cache.
105
107
  """
108
+ # Verification requires logit outputs on prefill for comparison.
109
+ if (
110
+ self.export_config is not None
111
+ and not self.export_config.output_logits_on_prefill
112
+ ):
113
+ raise ValueError("Verifier requires logit output on prefill.")
106
114
  # Since the reauthored model doesn't include keyword arguments, pass
107
115
  # pixel_values only when it is not None. Otherwise, it may raise an error.
108
116
  if pixel_values is None:
109
- output = self.model.forward(tokens, input_pos, kv_cache)
117
+ output = self.model.forward(
118
+ tokens, input_pos, kv_cache, self.export_config
119
+ )
110
120
  else:
111
121
  output = self.model.forward(
112
- tokens, input_pos, kv_cache, pixel_values=pixel_values
122
+ tokens,
123
+ input_pos,
124
+ kv_cache,
125
+ pixel_values=pixel_values,
126
+ export_config=self.export_config,
113
127
  )
114
128
  return output["logits"], output["kv_cache"]
115
129
 
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241209"
16
+ __version__ = "0.3.0.dev20241210"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241209
3
+ Version: 0.3.0.dev20241210
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=vqIQlpl8YMYWieiQ_ll8mjORaCsx5yY_5yEP0XERIiw,706
6
+ ai_edge_torch/version.py,sha256=AYxcupivW-iYIlCjWXl-QtEvpRsQqFcNK9I6uyGDqaU,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -40,50 +40,50 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
40
40
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
41
41
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
- ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=bkq2ZknJfuY7WC8wLVg92Z6eA_aMDbkgwaMxvmDW4_0,2618
44
- ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n79r6yFnCACpms5eMkXNpyQsCn2PYVRdB-jOoIqn14,2227
43
+ ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
44
+ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
45
45
  ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
46
46
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
- ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=mrG96_WEGD4NQ4uFEKrHRMAQvVVliOcj1zbI3drGDjI,2199
48
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=I_tvwCYmtf08D1HqDxYx7dpvj2q5_eaYnuI_3rI6Dlw,2201
49
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
50
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
47
+ ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
48
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
49
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=N0jKVZA3qWKOaHVbIM3WmQh3u0Sq7MTw_oO3Zo16wCw,3456
50
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
51
51
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
52
52
  ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
53
53
  ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
54
54
  ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
55
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=Brb83sbqBfStUiIZFhfWnYtN7LcNmkKyFn96cZK4sGo,2426
56
- ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
55
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=ck7tXN0U25wAbbRjDcf-aqiS2YhismkmoZIsMpjIsjc,2536
56
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=BMjpdw6oOXmtqXCAfW9o7Iewaj-Hxd57xVrvSLBuHTk,6656
57
57
  ai_edge_torch/generative/examples/llama/verify.py,sha256=X7oKQi85M789ugBrOlMvzk8eSRR3Kf1Mprfl-U-WIpo,2842
58
58
  ai_edge_torch/generative/examples/moonshine/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
59
59
  ai_edge_torch/generative/examples/moonshine/convert_moonshine_to_tflite.py,sha256=7m3rYRzThRDYb-7pGnpLr3ACi4PWX07Mg20Q98ArPc4,1714
60
60
  ai_edge_torch/generative/examples/moonshine/moonshine.py,sha256=nZ2b8u4TmsB5sgdClgAuH8E78bcTv9RCnF9666HqP2M,3394
61
61
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
62
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=-qDBu3bjUq0jx73SPDMsPIBP0BT1nA_0UgtFKeSuM18,2213
63
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
62
+ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=nji1oDgf6xImvGh95--8cNl3QPs-Xml2XBgNJB_c2hY,2323
63
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sIJ8Ie1oxFrJM-1jvv2ukiJbQOTIUGuMEZvmwZbt3n0,4556
64
64
  ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
65
65
  ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
66
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
67
- ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
66
+ ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=rPFqcsv8RHvjmgfBW9OL6EKxMtVX-ySjBsMP4N8FErk,2816
67
+ ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=f_A3GWcLrP0nRq2Tq-fThfXIQVJ-EYWoExYLO_6iVIQ,4866
68
68
  ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
69
69
  ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
70
70
  ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
71
71
  ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
72
72
  ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
73
73
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
74
- ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=ruY-LLwpqBqVZ5z9h_sewYj04ukWRG4j804tUAyDdnA,2186
75
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=UdMk1SSkcWpv8gosUylx3JSCxdOJBjZNhuQQtT4-Ono,2184
76
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=nbivDwZREd-sypy_ittO59-yaAdPvHv1YEV6Fo5buCo,3341
77
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=GkHOaYfsFEbHvfZCaLlb3Us_h19ezqPDUakoz_DiG9A,7123
74
+ ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=cD8rtwgYeGrXB9sYVV_D1AB8Up1AWNS-1XtrRlyzE5o,2296
75
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=G1i_ybDCTBaOD1OOCTk6jqOf__xYYZvhXcxY8MXhPHw,2294
76
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=c6PYCky7yJn6MVIYOCTx8S_CH27kOPmJbRZcI95nbZs,3477
77
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=7Y1E4XpRuZOiSbeZJ-C2uJjmlnDtWv6L0XvPRE8oEQs,7112
78
78
  ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
79
79
  ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
80
80
  ai_edge_torch/generative/examples/qwen/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
81
- ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=1M3DTkf536TCLYcQB1lu-3TxQ6mV03dFhTdbk0p8i84,2523
82
- ai_edge_torch/generative/examples/qwen/qwen.py,sha256=oYm9hhALUQ4uOn-PO1bF7fCIGP8EWRNK4zClkx2RQs8,4070
81
+ ai_edge_torch/generative/examples/qwen/convert_to_tflite.py,sha256=BHkDsivbbfVBPxknkgWwtLskvxyrd42TXuCB0aLVbMY,2633
82
+ ai_edge_torch/generative/examples/qwen/qwen.py,sha256=Zi_qiQ1JPokXZ95jgSEnQp3F-LKzFCvWvFLKhJjnASo,4199
83
83
  ai_edge_torch/generative/examples/qwen/verify.py,sha256=9_AyEJTeUfvhhID64Rto2bflFPyXMFokdQLsseLUMiI,2775
84
84
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
85
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=56CzCjyp9xh_2ZtXKN9tlEv6GayeSc4giTIZsi2Q59E,2194
86
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=M5qAcSUE5gxOSfq24a8lZku9kgvmlFCyIBar3kF2XEk,2570
85
+ ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=RKmSBMrup5A2bsPPaTdrBQb8NeRiUHy_1SUOA8DAs9U,2305
86
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=kk3cB_qaCzbFOhHtJlLb7qvSEBQTsILnoAcSFE3AkpE,2711
87
87
  ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
88
88
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
89
89
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
@@ -107,10 +107,10 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
107
107
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
108
108
  ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
109
109
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
110
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=rRodLr-hEqAs_-8x06O8qO-hJ_cqr2AfhJZ9DCptvwo,4332
110
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
111
111
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
112
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=WmEshoN9HgNLbV7NTjxdqWz9Olcim6r_vo4R9eYE98I,2228
113
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=10X8HwPx4akzclnIMOBNItKQemhRbvxBbTo7nwZtWjM,2650
112
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
113
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
114
114
  ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
115
115
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
116
116
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
@@ -139,20 +139,20 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
139
139
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
140
140
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
141
141
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
142
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
143
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
142
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=i3tQ6mEAo9lCctNoqFAnULk94hgKncC4ywn8IvgbUOo,6341
143
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IBuvXvORtHu3khr3mLJzYXyCd-zQLUdURTfH28Oo9e0,11079
144
144
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
145
145
  ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
146
146
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
147
- ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
147
+ ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
148
148
  ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
149
149
  ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
150
- ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
150
+ ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
151
151
  ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
152
152
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
153
153
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
154
- ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
155
- ai_edge_torch/generative/utilities/verifier.py,sha256=GLh7h8pcpSKtCKoPyxJhv3TmvENd2h6ek_cnbe2s3Ak,11418
154
+ ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=nHmI27ybu7lj8Ufw2LzmCwRDqEwNppIFNTx5ltLHIgE,1547
155
+ ai_edge_torch/generative/utilities/verifier.py,sha256=1NcmT_55Sb5e5spnHab4x5wqJZi2CKKVtXuXgK3lE6Q,11927
156
156
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
157
157
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
158
158
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -200,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
200
200
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
201
201
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
202
202
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
203
- ai_edge_torch_nightly-0.3.0.dev20241209.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
- ai_edge_torch_nightly-0.3.0.dev20241209.dist-info/METADATA,sha256=CgWO0uYKYs2y1On-UazTKeOqO9EIIof0veLSK5q0wmk,1897
205
- ai_edge_torch_nightly-0.3.0.dev20241209.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
- ai_edge_torch_nightly-0.3.0.dev20241209.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
- ai_edge_torch_nightly-0.3.0.dev20241209.dist-info/RECORD,,
203
+ ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
+ ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/METADATA,sha256=SM6aXiKe6YYFKtS0NbSZwwYIdZES74y0X7wautX45S4,1897
205
+ ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
+ ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
+ ai_edge_torch_nightly-0.3.0.dev20241210.dist-info/RECORD,,