ai-edge-torch-nightly 0.3.0.dev20240925__py3-none-any.whl → 0.3.0.dev20240927__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (25) hide show
  1. ai_edge_torch/generative/examples/gemma/gemma2.py +0 -2
  2. ai_edge_torch/generative/examples/gemma/verify_gemma1.py +3 -2
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +3 -2
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +15 -25
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +203 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/verify.py +19 -11
  12. ai_edge_torch/generative/examples/phi/phi3.py +15 -21
  13. ai_edge_torch/generative/examples/phi/verify.py +13 -12
  14. ai_edge_torch/generative/examples/phi/verify_phi3.py +13 -12
  15. ai_edge_torch/generative/examples/smollm/verify.py +19 -9
  16. ai_edge_torch/generative/examples/tiny_llama/verify.py +20 -10
  17. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  18. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  19. ai_edge_torch/generative/utilities/verifier.py +130 -114
  20. ai_edge_torch/version.py +1 -1
  21. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/METADATA +1 -1
  22. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/RECORD +25 -18
  23. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/LICENSE +0 -0
  24. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/WHEEL +0 -0
  25. {ai_edge_torch_nightly-0.3.0.dev20240925.dist-info → ai_edge_torch_nightly-0.3.0.dev20240927.dist-info}/top_level.txt +0 -0
@@ -15,116 +15,130 @@
15
15
 
16
16
  """Common utility functions to verify the reauthored models."""
17
17
 
18
- import datetime
19
- from typing import List, Optional, Union
18
+ import logging
19
+ from typing import List
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
22
  import torch
23
- import transformers
24
-
25
-
26
- def log_msg(*args):
27
- print("[%s]" % datetime.datetime.now(), *args)
28
23
 
29
24
 
30
25
  class ModelWrapper(torch.nn.Module):
31
- """A wrapper for the model to be verified, this could be a HuggingFace model
26
+ """A wrapper for the model to be verified.
32
27
 
33
- or a regular PyTorch model.
28
+ It unifies the interface of forward() and generate() of models for the
29
+ verification to call.
34
30
  """
35
31
 
36
- def __init__(
37
- self,
38
- model: torch.nn.Module,
39
- model_format: str = "huggingface",
40
- hf_generation_config: Optional[transformers.GenerationConfig] = None,
41
- ):
32
+ def __init__(self, model: torch.nn.Module):
42
33
  """Initializes the wrapper.
43
34
 
44
35
  Args:
45
- model (torch.nn.Module): The original model. This could be a model built
46
- from HuggingFace transformers, or a regular PyTorch model.
47
- model_format (str): The format of the model. It should be either
48
- "huggingface" or "pytorch".
49
- hf_generation_config (transformers.GenerationConfig): The HuggingFace
50
- generation config. This config will only be used if the underlying model
51
- is built from HuggingFace transformers.
36
+ model (torch.nn.Module): The model which might have different interfaces
37
+ of forward() and generate(). It could be a model built from HuggingFace
38
+ transformers, a regular PyTorch model, or a model re-authored with
39
+ ai_edge_torch Generative API.
52
40
  """
53
41
  super().__init__()
54
42
  self.model = model
55
- self.model_format = model_format
56
- self.hf_generation_config = hf_generation_config
43
+
44
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
45
+ """Gets output logits by forwarding the input tokens.
46
+
47
+ Args:
48
+ tokens (torch.Tensor): The input tokens to forward. Its dimension is
49
+ expected to be (batch_size=1, kv_cache_max_len).
50
+
51
+ Returns:
52
+ The output logits.
53
+ """
54
+ raise NotImplementedError("forward() is not implemented.")
57
55
 
58
56
  def generate(
59
- self, inputs: torch.Tensor
60
- ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
61
- if self.model_format == "huggingface":
62
- return self.model.generate(
63
- inputs=inputs, generation_config=self.hf_generation_config
64
- )
65
- else:
66
- raise NotImplementedError(
67
- "generate() is not implemented for model format: %s"
68
- % self.model_format
69
- )
57
+ self, prompts: torch.Tensor, max_new_tokens: int
58
+ ) -> torch.IntTensor:
59
+ """Returns the response token IDs to the given prompts tensor.
70
60
 
71
- def forward(
72
- self,
73
- inputs: torch.Tensor,
74
- ):
75
- return self.model.forward(inputs)
61
+ The maximum number of tokens to generate might be set by subclasses.
76
62
 
63
+ Args:
64
+ prompts (torch.Tensor): The input token IDs to generate with. Its shape is
65
+ expected to be (batch_size=1, input_ids_len).
66
+ max_new_tokens (int): The maximum number of response token IDs to
67
+ generate.
68
+
69
+ Returns:
70
+ The tensor of response token IDs with shape of (batch_size=1,
71
+ response_ids_len).
72
+ """
73
+ raise NotImplementedError("generate() is not implemented.")
77
74
 
78
- def forward(
79
- model: torch.nn.Module,
80
- tokens: torch.Tensor,
81
- kv_cache: kv_utils.KVCache,
82
- ) -> tuple[torch.Tensor, kv_utils.KVCache]:
83
- """Forwards the model reauthored with ai_edge_torch Generative API.
84
75
 
85
- Args:
86
- model (torch.nn.Module): The model to forward. It should be a model built
87
- with ai_edge_torch Generative API.
88
- tokens (torch.Tensor): The input tokens to forward.
89
- kv_cache (KVCache): The KV cache to forward.
76
+ class ReauthoredModelWrapper(ModelWrapper):
77
+ """A wrapper for the model reauthored with ai_edge_torch Generative API."""
90
78
 
91
- Returns:
92
- The output logits and the updated KV cache.
93
- """
94
- input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
95
- output = model.forward(tokens, input_pos, kv_cache)
96
- return output["logits"], output["kv_cache"]
79
+ def _init_kv_cache(self):
80
+ """Returns an initialized KV cache."""
81
+ return kv_utils.KVCache.from_model_config(self.model.config)
97
82
 
83
+ def _forward_with_kv_cache(
84
+ self,
85
+ tokens: torch.Tensor,
86
+ kv_cache: kv_utils.KVCache,
87
+ ) -> tuple[torch.Tensor, kv_utils.KVCache]:
88
+ """Forwards the model and updates an external KV cache.
98
89
 
99
- def generate(
100
- model: torch.nn.Module, prompts: torch.Tensor, response_len: int
101
- ) -> torch.Tensor:
102
- """Generates the response to the prompts.
90
+ Args:
91
+ tokens (torch.Tensor): The input tokens to forward.
92
+ kv_cache (KVCache): The KV cache to forward.
103
93
 
104
- It appends tokens output by the model to the prompts and feeds them back to
105
- the model up to decode_len.
94
+ Returns:
95
+ The output logits and the updated KV cache.
96
+ """
97
+ input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
98
+ output = self.model.forward(tokens, input_pos, kv_cache)
99
+ return output["logits"], output["kv_cache"]
106
100
 
107
- Args:
108
- model (torch.nn.Module): The model to generate. It should be a model built
109
- with ai_edge_torch Generative API.
110
- prompts (torch.Tensor): The prompts to generate.
111
- response_len (int): The number of tokens to generate.
101
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
102
+ logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
103
+ return logits
112
104
 
113
- Returns:
114
- The generated tokens.
115
- """
116
- input_ids = prompts[0].int().tolist()
117
- kv_cache = kv_utils.KVCache.from_model_config(model.config)
118
- for _ in range(response_len - len(input_ids)):
119
- logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
120
- generated_token = logits[0][-1].argmax().item()
121
- input_ids.append(generated_token)
122
- return torch.tensor([input_ids])
105
+ def generate(
106
+ self, prompts: torch.Tensor, max_new_tokens: int
107
+ ) -> torch.IntTensor:
108
+ input_ids = prompts[0].int().tolist()
109
+ kv_cache = self._init_kv_cache()
110
+ for _ in range(max_new_tokens):
111
+ tokens = torch.tensor([input_ids])
112
+ logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
113
+ generated_token = logits[0][-1].argmax().item()
114
+ input_ids.append(generated_token)
115
+ return torch.tensor([input_ids])
116
+
117
+
118
+ class TokenizerWrapper(torch.nn.Module):
119
+ """A wrapper for the tokenizer used for verification."""
120
+
121
+ def __init__(self, tokenizer: torch.nn.Module):
122
+ """Initializes the wrapper.
123
+
124
+ Args:
125
+ tokenizer (torch.nn.Module): The tokenizer to wrap.
126
+ """
127
+ super().__init__()
128
+ self.tokenizer = tokenizer
129
+
130
+ def encode(self, prompts: str) -> torch.Tensor:
131
+ """Encodes the prompts to token IDs."""
132
+ return self.tokenizer.encode(prompts, return_tensors="pt")
133
+
134
+ def decode(self, token_ids: torch.Tensor) -> str:
135
+ """Decodes the token IDs to a string."""
136
+ return self.tokenizer.decode(token_ids)
123
137
 
124
138
 
125
139
  def verify_with_input_ids(
126
140
  original_model: ModelWrapper,
127
- reauthored_model: torch.nn.Module,
141
+ reauthored_model: ReauthoredModelWrapper,
128
142
  input_ids: List[int],
129
143
  kv_cache_max_len: int = 1024,
130
144
  rtol: float = 1e-05,
@@ -136,8 +150,8 @@ def verify_with_input_ids(
136
150
 
137
151
  Args:
138
152
  original_model (ModelWrapper): The original model.
139
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
140
- Generative API.
153
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
154
+ ai_edge_torch Generative API.
141
155
  input_ids (List[int]): The input token IDs to forward with.
142
156
  kv_cache_max_len (int): The maximum sequence length of the KV cache.
143
157
  rtol (float): The relative tolerance for the comparison.
@@ -149,16 +163,15 @@ def verify_with_input_ids(
149
163
  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
150
164
  tokens[0, : len(input_ids)] = torch.tensor([input_ids]).int()
151
165
 
152
- log_msg("Forwarding the original model...")
166
+ logging.info("Forwarding the original model...")
153
167
  outputs_original = original_model.forward(tokens)
154
- logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
155
- log_msg("logits_original: ", logits_original)
168
+ logits_original = outputs_original[0, len(input_ids) - 1, :]
169
+ logging.info("logits_original: %s", logits_original)
156
170
 
157
- log_msg("Forwarding the reauthored model...")
158
- kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
159
- outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
160
- logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
161
- log_msg("logits_reauthored:", logits_reauthored)
171
+ logging.info("Forwarding the reauthored model...")
172
+ outputs_reauthored = reauthored_model.forward(tokens)
173
+ logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
174
+ logging.info("logits_reauthored: %s", logits_reauthored)
162
175
 
163
176
  return torch.allclose(
164
177
  logits_original, logits_reauthored, rtol=rtol, atol=atol
@@ -167,9 +180,10 @@ def verify_with_input_ids(
167
180
 
168
181
  def verify_model_with_prompts(
169
182
  original_model: ModelWrapper,
170
- reauthored_model: torch.nn.Module,
171
- tokenizer: torch.nn.Module,
183
+ reauthored_model: ReauthoredModelWrapper,
184
+ tokenizer: TokenizerWrapper,
172
185
  prompts: str,
186
+ max_new_tokens: int,
173
187
  ) -> bool:
174
188
  """Verifies if the model reauthored generates the same answer of the oringal.
175
189
 
@@ -178,35 +192,36 @@ def verify_model_with_prompts(
178
192
 
179
193
  Args:
180
194
  original_model (ModelWrapper): The original model.
181
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
182
- Generative API.
183
- tokenizer (torch.nn.Module): The tokenizer.
195
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
196
+ ai_edge_torch Generative API.
197
+ tokenizer (TokenizerWrapper): The tokenizer.
184
198
  prompts (str): The input prompts to generate answers.
199
+ max_new_tokens (int): The maximum number of new tokens to generate.
185
200
 
186
201
  Returns:
187
202
  True if the model reauthored generates the same answer of the original.
188
203
  """
189
- prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
204
+ prompt_tokens = tokenizer.encode(prompts)
190
205
 
191
- log_msg("Generating answer with the original model...")
192
- outputs_original = original_model.generate(prompt_tokens)
206
+ logging.info("Generating answer with the original model...")
207
+ outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
193
208
  response_original = tokenizer.decode(outputs_original[0])
194
- log_msg("outputs_from_original_model: [[", response_original, "]]")
209
+ logging.info("outputs_from_original_model: [[%s]]", response_original)
195
210
 
196
- log_msg("Generating answer with the reauthored model...")
197
- generate_len = len(outputs_original[0])
198
- outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
211
+ logging.info("Generating answer with the reauthored model...")
212
+ outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
199
213
  response_reauthored = tokenizer.decode(outputs_reauthored[0])
200
- log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
214
+ logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
201
215
 
202
216
  return response_original == response_reauthored
203
217
 
204
218
 
205
219
  def verify_reauthored_model(
206
220
  original_model: ModelWrapper,
207
- reauthored_model: torch.nn.Module,
208
- tokenizer: torch.nn.Module,
221
+ reauthored_model: ReauthoredModelWrapper,
222
+ tokenizer: TokenizerWrapper,
209
223
  generate_prompts: List[str],
224
+ max_new_tokens: int = 30,
210
225
  forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
211
226
  rtol: float = 1e-05,
212
227
  atol: float = 1e-05,
@@ -223,29 +238,30 @@ def verify_reauthored_model(
223
238
 
224
239
  Args:
225
240
  original_model (ModelWrapper): The original model.
226
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
227
- Generative API.
228
- tokenizer (torch.nn.Module): The tokenizer.
241
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
242
+ ai_edge_torch Generative API.
243
+ tokenizer (TokenizerWrapper): The tokenizer.
229
244
  generate_prompts (List[str]): List of the input prompts to generate answers.
245
+ max_new_tokens (int): The maximum number of new tokens to generate.
230
246
  forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
231
247
  forward with.
232
248
  rtol (float): The relative tolerance for the comparison.
233
249
  atol (float): The absolute tolerance for the comparison.
234
250
  """
235
251
  for input_ids in forward_input_ids:
236
- log_msg("Verifying the reauthored model with input IDs:", input_ids)
252
+ logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
237
253
  if verify_with_input_ids(
238
254
  original_model, reauthored_model, input_ids, rtol=rtol, atol=atol
239
255
  ):
240
- log_msg("PASS")
256
+ logging.info("PASS")
241
257
  else:
242
- log_msg("FAILED")
258
+ logging.error("FAILED")
243
259
 
244
260
  for prompts in generate_prompts:
245
- log_msg("Verifying the reauthored model with prompts:", prompts)
261
+ logging.info("Verifying the reauthored model with prompts:%s", prompts)
246
262
  if verify_model_with_prompts(
247
- original_model, reauthored_model, tokenizer, prompts
263
+ original_model, reauthored_model, tokenizer, prompts, max_new_tokens
248
264
  ):
249
- log_msg("PASS")
265
+ logging.info("PASS")
250
266
  else:
251
- log_msg("FAILED")
267
+ logging.error("FAILED")
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240925"
16
+ __version__ = "0.3.0.dev20240927"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240925
3
+ Version: 0.3.0.dev20240927
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=UXj1-90S3RDoHwYSmy9VdMC0Sm3EHt9ESLZbi3hnWus,706
6
+ ai_edge_torch/version.py,sha256=Z1S1T2LEv6zuiaCK0d-JIiQdzcipcMJB4-4vgSwnHyU,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,25 +42,31 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
42
42
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
43
43
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
44
44
  ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
45
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=KsM6VlzluTqbodG24IFr3biPxBrLay0z0gmnG0bcU2U,9277
46
- ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=B14IR4mIw6qBVUbiIRdfdUzHMCIJCJ0RFPsYOxA46qc,1776
47
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=kSzn1ITJXqrtNQax1m4WTAnC3gMxBpcvGA7-xTO1Iuw,1802
48
- ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=HBK2d8FcWFoxVDF5zk9sLSbKZEtwZQhX-K_zm4AvQtQ,5160
45
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=UziyJVrR_QXE_vFAagjnn1KluMM74coI89-UcdGTpkQ,9243
46
+ ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
47
+ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
48
+ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
49
+ ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
+ ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
51
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
52
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=NheDIa8JWiYhC9cIlw9vwGMIO_DEDSyV5Ay5masGV0Y,7120
53
+ ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
54
+ ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
49
55
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
56
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
51
57
  ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
52
- ai_edge_torch/generative/examples/openelm/verify.py,sha256=QdFKymQSCYFJcYVvA63u5uIsn1YxJ0JZD5UqN6gxraI,2112
58
+ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
53
59
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
54
60
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
55
61
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
56
62
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
57
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
58
- ai_edge_torch/generative/examples/phi/verify.py,sha256=5pQ0Bt8vGl8uTpkgXvOx8G7_rju0Gi8mIEr5NtRSAbs,2145
59
- ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=o1UTqpimkeX3MDjgdG1QTQkoZHvCEnGClA0J0WB3wJ4,2328
63
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=kf4K5uRxWvFeZBXpiIkqsFWg18u-_NfAijujyGbQqag,9254
64
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
+ ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
60
66
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
67
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
62
68
  ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
63
- ai_edge_torch/generative/examples/smollm/verify.py,sha256=G2dAcl-VhAbx1E1PEqM6hpzPF24HqFZaz7UBEpJSQ3w,2022
69
+ ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
64
70
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
65
71
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
66
72
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
@@ -87,7 +93,7 @@ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
87
93
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
88
94
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
89
95
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
90
- ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=LUChL5tA7FHL_DlTg5QKvGInmH9AwVVw9a-omcndiz8,2095
96
+ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
91
97
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
92
98
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
93
99
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -115,7 +121,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
115
121
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
116
122
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
117
123
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
118
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
124
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=kCm-L3rWbPj25E_QEbkSLiaCk3y23SjKJs-MG-EwKug,8545
119
125
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
120
126
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
121
127
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -123,7 +129,8 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
123
129
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
124
130
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
125
131
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
126
- ai_edge_torch/generative/utilities/verifier.py,sha256=vU9KgmFS7I9jNS_3H2SWROx-rbNqtMKgQC2MRhdqQ4g,8803
132
+ ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
133
+ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx07esOYCQYbDVLgZ2o,9520
127
134
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
128
135
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
129
136
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -170,8 +177,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
170
177
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
171
178
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
172
179
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
173
- ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
174
- ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/METADATA,sha256=5KsshdZ4-3X193HkoO2ukceyDEdWGvb8ZEMcw88qt7k,1897
175
- ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
176
- ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
177
- ai_edge_torch_nightly-0.3.0.dev20240925.dist-info/RECORD,,
180
+ ai_edge_torch_nightly-0.3.0.dev20240927.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
181
+ ai_edge_torch_nightly-0.3.0.dev20240927.dist-info/METADATA,sha256=06w25gO47Uf4Ky62kxwunGH2Y15EsPm5QGbmLcIlGvs,1897
182
+ ai_edge_torch_nightly-0.3.0.dev20240927.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
183
+ ai_edge_torch_nightly-0.3.0.dev20240927.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
184
+ ai_edge_torch_nightly-0.3.0.dev20240927.dist-info/RECORD,,