ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36):
  1. ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
  2. ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
  3. ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
  4. ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
  5. ai_edge_torch/generative/examples/llama/__init__.py +14 -0
  6. ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
  7. ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
  8. ai_edge_torch/generative/examples/llama/llama.py +204 -0
  9. ai_edge_torch/generative/examples/llama/verify.py +73 -0
  10. ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
  11. ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
  12. ai_edge_torch/generative/examples/openelm/verify.py +14 -7
  13. ai_edge_torch/generative/examples/phi/phi2.py +2 -6
  14. ai_edge_torch/generative/examples/phi/phi3.py +17 -24
  15. ai_edge_torch/generative/examples/phi/verify.py +8 -9
  16. ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
  17. ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
  18. ai_edge_torch/generative/examples/smollm/verify.py +14 -6
  19. ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
  20. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
  21. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
  22. ai_edge_torch/generative/examples/t5/t5.py +0 -2
  23. ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
  24. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
  25. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
  26. ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
  27. ai_edge_torch/generative/layers/model_config.py +2 -0
  28. ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
  29. ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
  30. ai_edge_torch/generative/utilities/verifier.py +117 -97
  31. ai_edge_torch/version.py +1 -1
  32. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
  34. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
  35. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
  36. {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -16,111 +16,129 @@
16
16
  """Common utility functions to verify the reauthored models."""
17
17
 
18
18
  import logging
19
- from typing import List, Optional, Union
19
+ from typing import List
20
20
 
21
21
  from ai_edge_torch.generative.layers import kv_cache as kv_utils
22
22
  import torch
23
- import transformers
24
23
 
25
24
 
26
25
  class ModelWrapper(torch.nn.Module):
27
- """A wrapper for the model to be verified, this could be a HuggingFace model
26
+ """A wrapper for the model to be verified.
28
27
 
29
- or a regular PyTorch model.
28
+ It unifies the interface of forward() and generate() of models for the
29
+ verification to call.
30
30
  """
31
31
 
32
- def __init__(
33
- self,
34
- model: torch.nn.Module,
35
- model_format: str = "huggingface",
36
- hf_generation_config: Optional[transformers.GenerationConfig] = None,
37
- ):
32
+ def __init__(self, model: torch.nn.Module):
38
33
  """Initializes the wrapper.
39
34
 
40
35
  Args:
41
- model (torch.nn.Module): The original model. This could be a model built
42
- from HuggingFace transformers, or a regular PyTorch model.
43
- model_format (str): The format of the model. It should be either
44
- "huggingface" or "pytorch".
45
- hf_generation_config (transformers.GenerationConfig): The HuggingFace
46
- generation config. This config will only be used if the underlying model
47
- is built from HuggingFace transformers.
36
+ model (torch.nn.Module): The model which might have different interfaces
37
+ of forward() and generate(). It could be a model built from HuggingFace
38
+ transformers, a regular PyTorch model, or a model re-authored with
39
+ ai_edge_torch Generative API.
48
40
  """
49
41
  super().__init__()
50
42
  self.model = model
51
- self.model_format = model_format
52
- self.hf_generation_config = hf_generation_config
43
+
44
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
45
+ """Gets output logits by forwarding the input tokens.
46
+
47
+ Args:
48
+ tokens (torch.Tensor): The input tokens to forward. Its dimension is
49
+ expected to be (batch_size=1, kv_cache_max_len).
50
+
51
+ Returns:
52
+ The output logits.
53
+ """
54
+ raise NotImplementedError("forward() is not implemented.")
53
55
 
54
56
  def generate(
55
- self, inputs: torch.Tensor
56
- ) -> Union[transformers.utils.ModelOutput, torch.LongTensor]:
57
- if self.model_format == "huggingface":
58
- return self.model.generate(
59
- inputs=inputs, generation_config=self.hf_generation_config
60
- )
61
- else:
62
- raise NotImplementedError(
63
- "generate() is not implemented for model format: %s"
64
- % self.model_format
65
- )
57
+ self, prompts: torch.Tensor, max_new_tokens: int
58
+ ) -> torch.IntTensor:
59
+ """Returns the response token IDs to the given prompts tensor.
66
60
 
67
- def forward(
68
- self,
69
- inputs: torch.Tensor,
70
- ):
71
- return self.model.forward(inputs)
61
+ The maximum number of tokens to generate might be set by subclasses.
72
62
 
63
+ Args:
64
+ prompts (torch.Tensor): The input token IDs to generate with. Its shape is
65
+ expected to be (batch_size=1, input_ids_len).
66
+ max_new_tokens (int): The maximum number of response token IDs to
67
+ generate.
68
+
69
+ Returns:
70
+ The tensor of response token IDs with shape of (batch_size=1,
71
+ response_ids_len).
72
+ """
73
+ raise NotImplementedError("generate() is not implemented.")
73
74
 
74
- def forward(
75
- model: torch.nn.Module,
76
- tokens: torch.Tensor,
77
- kv_cache: kv_utils.KVCache,
78
- ) -> tuple[torch.Tensor, kv_utils.KVCache]:
79
- """Forwards the model reauthored with ai_edge_torch Generative API.
80
75
 
81
- Args:
82
- model (torch.nn.Module): The model to forward. It should be a model built
83
- with ai_edge_torch Generative API.
84
- tokens (torch.Tensor): The input tokens to forward.
85
- kv_cache (KVCache): The KV cache to forward.
76
+ class ReauthoredModelWrapper(ModelWrapper):
77
+ """A wrapper for the model reauthored with ai_edge_torch Generative API."""
86
78
 
87
- Returns:
88
- The output logits and the updated KV cache.
89
- """
90
- input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
91
- output = model.forward(tokens, input_pos, kv_cache)
92
- return output["logits"], output["kv_cache"]
79
+ def _init_kv_cache(self):
80
+ """Returns an initialized KV cache."""
81
+ return kv_utils.KVCache.from_model_config(self.model.config)
93
82
 
83
+ def _forward_with_kv_cache(
84
+ self,
85
+ tokens: torch.Tensor,
86
+ kv_cache: kv_utils.KVCache,
87
+ ) -> tuple[torch.Tensor, kv_utils.KVCache]:
88
+ """Forwards the model and updates an external KV cache.
94
89
 
95
- def generate(
96
- model: torch.nn.Module, prompts: torch.Tensor, response_len: int
97
- ) -> torch.Tensor:
98
- """Generates the response to the prompts.
90
+ Args:
91
+ tokens (torch.Tensor): The input tokens to forward.
92
+ kv_cache (KVCache): The KV cache to forward.
99
93
 
100
- It appends tokens output by the model to the prompts and feeds them back to
101
- the model up to decode_len.
94
+ Returns:
95
+ The output logits and the updated KV cache.
96
+ """
97
+ input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
98
+ output = self.model.forward(tokens, input_pos, kv_cache)
99
+ return output["logits"], output["kv_cache"]
102
100
 
103
- Args:
104
- model (torch.nn.Module): The model to generate. It should be a model built
105
- with ai_edge_torch Generative API.
106
- prompts (torch.Tensor): The prompts to generate.
107
- response_len (int): The number of tokens to generate.
101
+ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
102
+ logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
103
+ return logits
108
104
 
109
- Returns:
110
- The generated tokens.
111
- """
112
- input_ids = prompts[0].int().tolist()
113
- kv_cache = kv_utils.KVCache.from_model_config(model.config)
114
- for _ in range(response_len - len(input_ids)):
115
- logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
116
- generated_token = logits[0][-1].argmax().item()
117
- input_ids.append(generated_token)
118
- return torch.tensor([input_ids])
105
+ def generate(
106
+ self, prompts: torch.Tensor, max_new_tokens: int
107
+ ) -> torch.IntTensor:
108
+ input_ids = prompts[0].int().tolist()
109
+ kv_cache = self._init_kv_cache()
110
+ for _ in range(max_new_tokens):
111
+ tokens = torch.tensor([input_ids])
112
+ logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
113
+ generated_token = logits[0][-1].argmax().item()
114
+ input_ids.append(generated_token)
115
+ return torch.tensor([input_ids])
116
+
117
+
118
+ class TokenizerWrapper(torch.nn.Module):
119
+ """A wrapper for the tokenizer used for verification."""
120
+
121
+ def __init__(self, tokenizer: torch.nn.Module):
122
+ """Initializes the wrapper.
123
+
124
+ Args:
125
+ tokenizer (torch.nn.Module): The tokenizer to wrap.
126
+ """
127
+ super().__init__()
128
+ self.tokenizer = tokenizer
129
+
130
+ def encode(self, prompts: str) -> torch.Tensor:
131
+ """Encodes the prompts to token IDs."""
132
+ return self.tokenizer.encode(prompts, return_tensors="pt")
133
+
134
+ def decode(self, token_ids: torch.Tensor) -> str:
135
+ """Decodes the token IDs to a string."""
136
+ return self.tokenizer.decode(token_ids)
119
137
 
120
138
 
121
139
  def verify_with_input_ids(
122
140
  original_model: ModelWrapper,
123
- reauthored_model: torch.nn.Module,
141
+ reauthored_model: ReauthoredModelWrapper,
124
142
  input_ids: List[int],
125
143
  kv_cache_max_len: int = 1024,
126
144
  rtol: float = 1e-05,
@@ -132,8 +150,8 @@ def verify_with_input_ids(
132
150
 
133
151
  Args:
134
152
  original_model (ModelWrapper): The original model.
135
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
136
- Generative API.
153
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
154
+ ai_edge_torch Generative API.
137
155
  input_ids (List[int]): The input token IDs to forward with.
138
156
  kv_cache_max_len (int): The maximum sequence length of the KV cache.
139
157
  rtol (float): The relative tolerance for the comparison.
@@ -147,13 +165,12 @@ def verify_with_input_ids(
147
165
 
148
166
  logging.info("Forwarding the original model...")
149
167
  outputs_original = original_model.forward(tokens)
150
- logits_original = outputs_original.logits[0, len(input_ids) - 1, :]
168
+ logits_original = outputs_original[0, len(input_ids) - 1, :]
151
169
  logging.info("logits_original: %s", logits_original)
152
170
 
153
171
  logging.info("Forwarding the reauthored model...")
154
- kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
155
- outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
156
- logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
172
+ outputs_reauthored = reauthored_model.forward(tokens)
173
+ logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
157
174
  logging.info("logits_reauthored: %s", logits_reauthored)
158
175
 
159
176
  return torch.allclose(
@@ -163,9 +180,10 @@ def verify_with_input_ids(
163
180
 
164
181
  def verify_model_with_prompts(
165
182
  original_model: ModelWrapper,
166
- reauthored_model: torch.nn.Module,
167
- tokenizer: torch.nn.Module,
183
+ reauthored_model: ReauthoredModelWrapper,
184
+ tokenizer: TokenizerWrapper,
168
185
  prompts: str,
186
+ max_new_tokens: int,
169
187
  ) -> bool:
170
188
  """Verifies if the model reauthored generates the same answer of the oringal.
171
189
 
@@ -174,24 +192,24 @@ def verify_model_with_prompts(
174
192
 
175
193
  Args:
176
194
  original_model (ModelWrapper): The original model.
177
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
178
- Generative API.
179
- tokenizer (torch.nn.Module): The tokenizer.
195
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
196
+ ai_edge_torch Generative API.
197
+ tokenizer (TokenizerWrapper): The tokenizer.
180
198
  prompts (str): The input prompts to generate answers.
199
+ max_new_tokens (int): The maximum number of new tokens to generate.
181
200
 
182
201
  Returns:
183
202
  True if the model reauthored generates the same answer of the original.
184
203
  """
185
- prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
204
+ prompt_tokens = tokenizer.encode(prompts)
186
205
 
187
206
  logging.info("Generating answer with the original model...")
188
- outputs_original = original_model.generate(prompt_tokens)
207
+ outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
189
208
  response_original = tokenizer.decode(outputs_original[0])
190
209
  logging.info("outputs_from_original_model: [[%s]]", response_original)
191
210
 
192
211
  logging.info("Generating answer with the reauthored model...")
193
- generate_len = len(outputs_original[0])
194
- outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
212
+ outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
195
213
  response_reauthored = tokenizer.decode(outputs_reauthored[0])
196
214
  logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
197
215
 
@@ -200,9 +218,10 @@ def verify_model_with_prompts(
200
218
 
201
219
  def verify_reauthored_model(
202
220
  original_model: ModelWrapper,
203
- reauthored_model: torch.nn.Module,
204
- tokenizer: torch.nn.Module,
221
+ reauthored_model: ReauthoredModelWrapper,
222
+ tokenizer: TokenizerWrapper,
205
223
  generate_prompts: List[str],
224
+ max_new_tokens: int = 30,
206
225
  forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
207
226
  rtol: float = 1e-05,
208
227
  atol: float = 1e-05,
@@ -219,10 +238,11 @@ def verify_reauthored_model(
219
238
 
220
239
  Args:
221
240
  original_model (ModelWrapper): The original model.
222
- reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
223
- Generative API.
224
- tokenizer (torch.nn.Module): The tokenizer.
241
+ reauthored_model (ReauthoredModelWrapper): The model reauthored with
242
+ ai_edge_torch Generative API.
243
+ tokenizer (TokenizerWrapper): The tokenizer.
225
244
  generate_prompts (List[str]): List of the input prompts to generate answers.
245
+ max_new_tokens (int): The maximum number of new tokens to generate.
226
246
  forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
227
247
  forward with.
228
248
  rtol (float): The relative tolerance for the comparison.
@@ -235,13 +255,13 @@ def verify_reauthored_model(
235
255
  ):
236
256
  logging.info("PASS")
237
257
  else:
238
- logging.info("FAILED")
258
+ logging.error("FAILED")
239
259
 
240
260
  for prompts in generate_prompts:
241
261
  logging.info("Verifying the reauthored model with prompts:%s", prompts)
242
262
  if verify_model_with_prompts(
243
- original_model, reauthored_model, tokenizer, prompts
263
+ original_model, reauthored_model, tokenizer, prompts, max_new_tokens
244
264
  ):
245
265
  logging.info("PASS")
246
266
  else:
247
- logging.info("FAILED")
267
+ logging.error("FAILED")
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240926"
16
+ __version__ = "0.3.0.dev20240928"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240926
3
+ Version: 0.3.0.dev20240928
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=e_e6TIee2wiwaLShw_LBUsVRwNFQHZYOv7WtAcrMix4,706
6
+ ai_edge_torch/version.py,sha256=YiCjdglLzSPYyRq64U8zJSgWDFqJs-t2JSzuA0bYYzA,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -41,32 +41,38 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
41
41
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
43
43
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
44
- ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=cahMzvJNJfShIw4uqoBRX5iBZrI3rvsha6wpNHzeYJ0,6369
45
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=UziyJVrR_QXE_vFAagjnn1KluMM74coI89-UcdGTpkQ,9243
44
+ ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=kxWmmoVvtLP5auB3UXA2vsvZmSnpBs4SBixzYeAXzVA,6255
45
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=7VF5RYJ8QhROQNIlx-QovO-y6-jFp_EHgAkBNChZaqE,9066
46
46
  ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
47
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=p0wmhUDkNj6m9QZVssHjSc2HzIxci_1XMvZqS1X3xK4,1818
48
- ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=vQp_QN8nRXv81-wrI8h8722jV68VEUMGSrpXa5kRk68,5175
47
+ ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
48
+ ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
49
+ ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
+ ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
51
+ ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
52
+ ai_edge_torch/generative/examples/llama/llama.py,sha256=5vlh2Z8vEPH8Z4LoHoFYCcuOQynx4mbVE37v3yMl1hE,7162
53
+ ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
54
+ ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
49
55
  ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
56
  ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
51
- ai_edge_torch/generative/examples/openelm/openelm.py,sha256=VcU8A0B9nQR-FTPHXqNHSHZzeIZZ_As4yvKZMnoU2P4,7482
52
- ai_edge_torch/generative/examples/openelm/verify.py,sha256=_tglrzub_qAVDHGacriyzPFlKeIFhZ4KnDGkZ22g7So,2127
57
+ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hxbpvk0fNswzbqZfGteflqKMmkH7yzeMuW6r29s_xnQ,7374
58
+ ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
53
59
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
54
60
  ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
55
61
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
56
- ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
57
- ai_edge_torch/generative/examples/phi/phi3.py,sha256=DIDzpG8DZkWDcWsAVkcxzxIC3U3352uVI3zMoYZD16U,9554
58
- ai_edge_torch/generative/examples/phi/verify.py,sha256=HJ7RkE0CTtpTdu_pgW4VX37_4Q20Ow6nG1DtA_eXsZE,2161
59
- ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=7ygNNafJBu803JhxQm0aXMccCy6YX6-D-DT9k4QCyL0,2344
62
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_XaXWLi4Zyw9Vtmj0,6075
63
+ ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
64
+ ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
65
+ ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
60
66
  ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
67
  ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
62
- ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
63
- ai_edge_torch/generative/examples/smollm/verify.py,sha256=FhnNiISi8JZWQaA-AaL9giS7IzkBAvcHkuDFBISj4Zg,2038
68
+ ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
69
+ ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
64
70
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
65
71
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
66
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=2RMi5UmfMT4Ep68ZLJsqF-fMvEumNVkIwqtsRli9HhA,6068
72
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lwWrKY1NpnbvHQRenpltVN65QlzjWmSScl5CLSipBkc,6110
67
73
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
68
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ZTRD56e8MsdGPJr7vpLa4Ju_BFw_b-FUgXgd-SO5MBw,15665
69
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6FAnevL8ZfCK2YCSPivarUH0Z8wGKSmnPpJNC0OI5A8,33680
74
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
75
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
70
76
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
71
77
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
72
78
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -78,16 +84,16 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
78
84
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
79
85
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
80
86
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
81
- ai_edge_torch/generative/examples/t5/t5.py,sha256=OZ67knK-UB1dBjxydG-Jwkp0Z3FzOCqGPTdg5aBFu4w,21328
87
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=gFTmPi-xB8pcPRgoF3DJxvH_fT-KWTb8ii77P5UbKR0,21263
82
88
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
83
89
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
84
90
  ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
85
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LTuzres5DHmrMT6U9rCrGf6vmR9SmopmB8sO6Cd2NxQ,5255
86
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=xDYTh4m3vBEb6r3_ERhmj5qILW7YdVDAnZ-fitgYONg,4450
91
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
92
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
87
93
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
88
94
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
89
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
90
- ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=Yt49ndXTL_fUfjrQjx-AyhD4x796IOTjuBegiQe7_Yc,2111
95
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=aSNHOAar5yPnGAeKsv8zrqYhOq9RR_7hwqHUMBb2mkM,5930
96
+ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
91
97
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
92
98
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
93
99
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -96,7 +102,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
96
102
  ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
97
103
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
98
104
  ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
99
- ai_edge_torch/generative/layers/model_config.py,sha256=l5Rb3h3GK2pux-Lg3BONTD6b7klxXqUbDDtYs_bGKLk,6879
105
+ ai_edge_torch/generative/layers/model_config.py,sha256=Fa0eFCMlyfdwd3cM1drhP9vlXRhIguDrglsHn4ax2_w,6948
100
106
  ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
101
107
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
102
108
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -115,7 +121,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
115
121
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
116
122
  ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
117
123
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
118
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=IzW2HjXS2-zePZM-qEuXL4zclnGvYsNw-6tuDSeNna4,8163
124
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=kCm-L3rWbPj25E_QEbkSLiaCk3y23SjKJs-MG-EwKug,8545
119
125
  ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
120
126
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
121
127
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -123,7 +129,8 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
123
129
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
124
130
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
125
131
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
126
- ai_edge_torch/generative/utilities/verifier.py,sha256=vo5haJqtae-M_s4qLMh1_ItU0zluwqekV8jQ-cEjXCQ,8807
132
+ ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
133
+ ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx07esOYCQYbDVLgZ2o,9520
127
134
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
128
135
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
129
136
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -170,8 +177,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
170
177
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
171
178
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
172
179
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
173
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
174
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/METADATA,sha256=r2brEhCKGcZsIaGpoJLT3DXvY9Yi-55_A-ZwjVtPTnk,1897
175
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
176
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
177
- ai_edge_torch_nightly-0.3.0.dev20240926.dist-info/RECORD,,
180
+ ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
181
+ ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/METADATA,sha256=3HuAFZTfvmU787dVypwpmUvo4DdZSekGsqGimO-oPfM,1897
182
+ ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
183
+ ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
184
+ ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/RECORD,,