ai-edge-torch-nightly 0.3.0.dev20240926__py3-none-any.whl → 0.3.0.dev20240928__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/gemma1.py +2 -6
- ai_edge_torch/generative/examples/gemma/gemma2.py +2 -8
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +0 -1
- ai_edge_torch/generative/examples/gemma/verify_util.py +13 -24
- ai_edge_torch/generative/examples/llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/convert_to_tflite.py +68 -0
- ai_edge_torch/generative/examples/llama/llama.py +204 -0
- ai_edge_torch/generative/examples/llama/verify.py +73 -0
- ai_edge_torch/generative/examples/llama/verify_3b.py +73 -0
- ai_edge_torch/generative/examples/openelm/openelm.py +2 -6
- ai_edge_torch/generative/examples/openelm/verify.py +14 -7
- ai_edge_torch/generative/examples/phi/phi2.py +2 -6
- ai_edge_torch/generative/examples/phi/phi3.py +17 -24
- ai_edge_torch/generative/examples/phi/verify.py +8 -9
- ai_edge_torch/generative/examples/phi/verify_phi3.py +8 -9
- ai_edge_torch/generative/examples/smollm/smollm.py +1 -0
- ai_edge_torch/generative/examples/smollm/verify.py +14 -6
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +2 -0
- ai_edge_torch/generative/examples/t5/t5.py +0 -2
- ai_edge_torch/generative/examples/test_models/toy_model.py +5 -10
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +2 -6
- ai_edge_torch/generative/examples/tiny_llama/verify.py +15 -7
- ai_edge_torch/generative/layers/model_config.py +2 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +10 -0
- ai_edge_torch/generative/utilities/transformers_verifier.py +42 -0
- ai_edge_torch/generative/utilities/verifier.py +117 -97
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/RECORD +36 -29
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240926.dist-info → ai_edge_torch_nightly-0.3.0.dev20240928.dist-info}/top_level.txt +0 -0
@@ -16,111 +16,129 @@
|
|
16
16
|
"""Common utility functions to verify the reauthored models."""
|
17
17
|
|
18
18
|
import logging
|
19
|
-
from typing import List
|
19
|
+
from typing import List
|
20
20
|
|
21
21
|
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
22
|
import torch
|
23
|
-
import transformers
|
24
23
|
|
25
24
|
|
26
25
|
class ModelWrapper(torch.nn.Module):
|
27
|
-
"""A wrapper for the model to be verified
|
26
|
+
"""A wrapper for the model to be verified.
|
28
27
|
|
29
|
-
|
28
|
+
It unifies the interface of forward() and generate() of models for the
|
29
|
+
verification to call.
|
30
30
|
"""
|
31
31
|
|
32
|
-
def __init__(
|
33
|
-
self,
|
34
|
-
model: torch.nn.Module,
|
35
|
-
model_format: str = "huggingface",
|
36
|
-
hf_generation_config: Optional[transformers.GenerationConfig] = None,
|
37
|
-
):
|
32
|
+
def __init__(self, model: torch.nn.Module):
|
38
33
|
"""Initializes the wrapper.
|
39
34
|
|
40
35
|
Args:
|
41
|
-
model (torch.nn.Module): The
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
hf_generation_config (transformers.GenerationConfig): The HuggingFace
|
46
|
-
generation config. This config will only be used if the underlying model
|
47
|
-
is built from HuggingFace transformers.
|
36
|
+
model (torch.nn.Module): The model which might have different interfaces
|
37
|
+
of forward() and generate(). It could be a model built from HuggingFace
|
38
|
+
transformers, a regular PyTorch model, or a model re-authored with
|
39
|
+
ai_edge_torch Generative API.
|
48
40
|
"""
|
49
41
|
super().__init__()
|
50
42
|
self.model = model
|
51
|
-
|
52
|
-
|
43
|
+
|
44
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
45
|
+
"""Gets output logits by forwarding the input tokens.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
tokens (torch.Tensor): The input tokens to forward. Its dimension is
|
49
|
+
expected to be (batch_size=1, kv_cache_max_len).
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
The output logits.
|
53
|
+
"""
|
54
|
+
raise NotImplementedError("forward() is not implemented.")
|
53
55
|
|
54
56
|
def generate(
|
55
|
-
self,
|
56
|
-
) ->
|
57
|
-
|
58
|
-
return self.model.generate(
|
59
|
-
inputs=inputs, generation_config=self.hf_generation_config
|
60
|
-
)
|
61
|
-
else:
|
62
|
-
raise NotImplementedError(
|
63
|
-
"generate() is not implemented for model format: %s"
|
64
|
-
% self.model_format
|
65
|
-
)
|
57
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
58
|
+
) -> torch.IntTensor:
|
59
|
+
"""Returns the response token IDs to the given prompts tensor.
|
66
60
|
|
67
|
-
|
68
|
-
self,
|
69
|
-
inputs: torch.Tensor,
|
70
|
-
):
|
71
|
-
return self.model.forward(inputs)
|
61
|
+
The maximum number of tokens to generate might be set by subclasses.
|
72
62
|
|
63
|
+
Args:
|
64
|
+
prompts (torch.Tensor): The input token IDs to generate with. Its shape is
|
65
|
+
expected to be (batch_size=1, input_ids_len).
|
66
|
+
max_new_tokens (int): The maximum number of response token IDs to
|
67
|
+
generate.
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
The tensor of response token IDs with shape of (batch_size=1,
|
71
|
+
response_ids_len).
|
72
|
+
"""
|
73
|
+
raise NotImplementedError("generate() is not implemented.")
|
73
74
|
|
74
|
-
def forward(
|
75
|
-
model: torch.nn.Module,
|
76
|
-
tokens: torch.Tensor,
|
77
|
-
kv_cache: kv_utils.KVCache,
|
78
|
-
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
79
|
-
"""Forwards the model reauthored with ai_edge_torch Generative API.
|
80
75
|
|
81
|
-
|
82
|
-
|
83
|
-
with ai_edge_torch Generative API.
|
84
|
-
tokens (torch.Tensor): The input tokens to forward.
|
85
|
-
kv_cache (KVCache): The KV cache to forward.
|
76
|
+
class ReauthoredModelWrapper(ModelWrapper):
|
77
|
+
"""A wrapper for the model reauthored with ai_edge_torch Generative API."""
|
86
78
|
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
91
|
-
output = model.forward(tokens, input_pos, kv_cache)
|
92
|
-
return output["logits"], output["kv_cache"]
|
79
|
+
def _init_kv_cache(self):
|
80
|
+
"""Returns an initialized KV cache."""
|
81
|
+
return kv_utils.KVCache.from_model_config(self.model.config)
|
93
82
|
|
83
|
+
def _forward_with_kv_cache(
|
84
|
+
self,
|
85
|
+
tokens: torch.Tensor,
|
86
|
+
kv_cache: kv_utils.KVCache,
|
87
|
+
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
88
|
+
"""Forwards the model and updates an external KV cache.
|
94
89
|
|
95
|
-
|
96
|
-
|
97
|
-
)
|
98
|
-
"""Generates the response to the prompts.
|
90
|
+
Args:
|
91
|
+
tokens (torch.Tensor): The input tokens to forward.
|
92
|
+
kv_cache (KVCache): The KV cache to forward.
|
99
93
|
|
100
|
-
|
101
|
-
|
94
|
+
Returns:
|
95
|
+
The output logits and the updated KV cache.
|
96
|
+
"""
|
97
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
98
|
+
output = self.model.forward(tokens, input_pos, kv_cache)
|
99
|
+
return output["logits"], output["kv_cache"]
|
102
100
|
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
prompts (torch.Tensor): The prompts to generate.
|
107
|
-
response_len (int): The number of tokens to generate.
|
101
|
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
102
|
+
logits, _ = self._forward_with_kv_cache(tokens, self._init_kv_cache())
|
103
|
+
return logits
|
108
104
|
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
105
|
+
def generate(
|
106
|
+
self, prompts: torch.Tensor, max_new_tokens: int
|
107
|
+
) -> torch.IntTensor:
|
108
|
+
input_ids = prompts[0].int().tolist()
|
109
|
+
kv_cache = self._init_kv_cache()
|
110
|
+
for _ in range(max_new_tokens):
|
111
|
+
tokens = torch.tensor([input_ids])
|
112
|
+
logits, kv_cache = self._forward_with_kv_cache(tokens, kv_cache)
|
113
|
+
generated_token = logits[0][-1].argmax().item()
|
114
|
+
input_ids.append(generated_token)
|
115
|
+
return torch.tensor([input_ids])
|
116
|
+
|
117
|
+
|
118
|
+
class TokenizerWrapper(torch.nn.Module):
|
119
|
+
"""A wrapper for the tokenizer used for verification."""
|
120
|
+
|
121
|
+
def __init__(self, tokenizer: torch.nn.Module):
|
122
|
+
"""Initializes the wrapper.
|
123
|
+
|
124
|
+
Args:
|
125
|
+
tokenizer (torch.nn.Module): The tokenizer to wrap.
|
126
|
+
"""
|
127
|
+
super().__init__()
|
128
|
+
self.tokenizer = tokenizer
|
129
|
+
|
130
|
+
def encode(self, prompts: str) -> torch.Tensor:
|
131
|
+
"""Encodes the prompts to token IDs."""
|
132
|
+
return self.tokenizer.encode(prompts, return_tensors="pt")
|
133
|
+
|
134
|
+
def decode(self, token_ids: torch.Tensor) -> str:
|
135
|
+
"""Decodes the token IDs to a string."""
|
136
|
+
return self.tokenizer.decode(token_ids)
|
119
137
|
|
120
138
|
|
121
139
|
def verify_with_input_ids(
|
122
140
|
original_model: ModelWrapper,
|
123
|
-
reauthored_model:
|
141
|
+
reauthored_model: ReauthoredModelWrapper,
|
124
142
|
input_ids: List[int],
|
125
143
|
kv_cache_max_len: int = 1024,
|
126
144
|
rtol: float = 1e-05,
|
@@ -132,8 +150,8 @@ def verify_with_input_ids(
|
|
132
150
|
|
133
151
|
Args:
|
134
152
|
original_model (ModelWrapper): The original model.
|
135
|
-
reauthored_model (
|
136
|
-
Generative API.
|
153
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
154
|
+
ai_edge_torch Generative API.
|
137
155
|
input_ids (List[int]): The input token IDs to forward with.
|
138
156
|
kv_cache_max_len (int): The maximum sequence length of the KV cache.
|
139
157
|
rtol (float): The relative tolerance for the comparison.
|
@@ -147,13 +165,12 @@ def verify_with_input_ids(
|
|
147
165
|
|
148
166
|
logging.info("Forwarding the original model...")
|
149
167
|
outputs_original = original_model.forward(tokens)
|
150
|
-
logits_original = outputs_original
|
168
|
+
logits_original = outputs_original[0, len(input_ids) - 1, :]
|
151
169
|
logging.info("logits_original: %s", logits_original)
|
152
170
|
|
153
171
|
logging.info("Forwarding the reauthored model...")
|
154
|
-
|
155
|
-
|
156
|
-
logits_reauthored = outputs_reauthored[0][0, len(input_ids) - 1, :]
|
172
|
+
outputs_reauthored = reauthored_model.forward(tokens)
|
173
|
+
logits_reauthored = outputs_reauthored[0, len(input_ids) - 1, :]
|
157
174
|
logging.info("logits_reauthored: %s", logits_reauthored)
|
158
175
|
|
159
176
|
return torch.allclose(
|
@@ -163,9 +180,10 @@ def verify_with_input_ids(
|
|
163
180
|
|
164
181
|
def verify_model_with_prompts(
|
165
182
|
original_model: ModelWrapper,
|
166
|
-
reauthored_model:
|
167
|
-
tokenizer:
|
183
|
+
reauthored_model: ReauthoredModelWrapper,
|
184
|
+
tokenizer: TokenizerWrapper,
|
168
185
|
prompts: str,
|
186
|
+
max_new_tokens: int,
|
169
187
|
) -> bool:
|
170
188
|
"""Verifies if the model reauthored generates the same answer of the oringal.
|
171
189
|
|
@@ -174,24 +192,24 @@ def verify_model_with_prompts(
|
|
174
192
|
|
175
193
|
Args:
|
176
194
|
original_model (ModelWrapper): The original model.
|
177
|
-
reauthored_model (
|
178
|
-
Generative API.
|
179
|
-
tokenizer (
|
195
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
196
|
+
ai_edge_torch Generative API.
|
197
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
180
198
|
prompts (str): The input prompts to generate answers.
|
199
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
181
200
|
|
182
201
|
Returns:
|
183
202
|
True if the model reauthored generates the same answer of the original.
|
184
203
|
"""
|
185
|
-
prompt_tokens = tokenizer.encode(prompts
|
204
|
+
prompt_tokens = tokenizer.encode(prompts)
|
186
205
|
|
187
206
|
logging.info("Generating answer with the original model...")
|
188
|
-
outputs_original = original_model.generate(prompt_tokens)
|
207
|
+
outputs_original = original_model.generate(prompt_tokens, max_new_tokens)
|
189
208
|
response_original = tokenizer.decode(outputs_original[0])
|
190
209
|
logging.info("outputs_from_original_model: [[%s]]", response_original)
|
191
210
|
|
192
211
|
logging.info("Generating answer with the reauthored model...")
|
193
|
-
|
194
|
-
outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
|
212
|
+
outputs_reauthored = reauthored_model.generate(prompt_tokens, max_new_tokens)
|
195
213
|
response_reauthored = tokenizer.decode(outputs_reauthored[0])
|
196
214
|
logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
|
197
215
|
|
@@ -200,9 +218,10 @@ def verify_model_with_prompts(
|
|
200
218
|
|
201
219
|
def verify_reauthored_model(
|
202
220
|
original_model: ModelWrapper,
|
203
|
-
reauthored_model:
|
204
|
-
tokenizer:
|
221
|
+
reauthored_model: ReauthoredModelWrapper,
|
222
|
+
tokenizer: TokenizerWrapper,
|
205
223
|
generate_prompts: List[str],
|
224
|
+
max_new_tokens: int = 30,
|
206
225
|
forward_input_ids: List[List[int]] = [[1, 2, 3, 4]],
|
207
226
|
rtol: float = 1e-05,
|
208
227
|
atol: float = 1e-05,
|
@@ -219,10 +238,11 @@ def verify_reauthored_model(
|
|
219
238
|
|
220
239
|
Args:
|
221
240
|
original_model (ModelWrapper): The original model.
|
222
|
-
reauthored_model (
|
223
|
-
Generative API.
|
224
|
-
tokenizer (
|
241
|
+
reauthored_model (ReauthoredModelWrapper): The model reauthored with
|
242
|
+
ai_edge_torch Generative API.
|
243
|
+
tokenizer (TokenizerWrapper): The tokenizer.
|
225
244
|
generate_prompts (List[str]): List of the input prompts to generate answers.
|
245
|
+
max_new_tokens (int): The maximum number of new tokens to generate.
|
226
246
|
forward_input_ids (List[torch.Tensor]): List if ihe input token IDs to
|
227
247
|
forward with.
|
228
248
|
rtol (float): The relative tolerance for the comparison.
|
@@ -235,13 +255,13 @@ def verify_reauthored_model(
|
|
235
255
|
):
|
236
256
|
logging.info("PASS")
|
237
257
|
else:
|
238
|
-
logging.
|
258
|
+
logging.error("FAILED")
|
239
259
|
|
240
260
|
for prompts in generate_prompts:
|
241
261
|
logging.info("Verifying the reauthored model with prompts:%s", prompts)
|
242
262
|
if verify_model_with_prompts(
|
243
|
-
original_model, reauthored_model, tokenizer, prompts
|
263
|
+
original_model, reauthored_model, tokenizer, prompts, max_new_tokens
|
244
264
|
):
|
245
265
|
logging.info("PASS")
|
246
266
|
else:
|
247
|
-
logging.
|
267
|
+
logging.error("FAILED")
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240928
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=YiCjdglLzSPYyRq64U8zJSgWDFqJs-t2JSzuA0bYYzA,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -41,32 +41,38 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
|
|
41
41
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
42
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
|
43
43
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
|
44
|
-
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=
|
45
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
44
|
+
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=kxWmmoVvtLP5auB3UXA2vsvZmSnpBs4SBixzYeAXzVA,6255
|
45
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=7VF5RYJ8QhROQNIlx-QovO-y6-jFp_EHgAkBNChZaqE,9066
|
46
46
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
47
|
-
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=
|
48
|
-
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=K77k-JpdhIwm3tbBnzpw8HQsFRwAVyszxRo82fR6-q4,1762
|
48
|
+
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=sqltZbnyKemNvKqqi9d09i74gP-PPQFodRYfDfnhycQ,4933
|
49
|
+
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
+
ai_edge_torch/generative/examples/llama/convert_3b_to_tflite.py,sha256=_OrerrTA6tvP9Tnwj601QO95Cm8PlOiYP-mxvtmBmb4,2186
|
51
|
+
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=GGo6Kxiwqva4JfurGx3InU3nROW70XtYvxUwEf_6mBQ,2180
|
52
|
+
ai_edge_torch/generative/examples/llama/llama.py,sha256=5vlh2Z8vEPH8Z4LoHoFYCcuOQynx4mbVE37v3yMl1hE,7162
|
53
|
+
ai_edge_torch/generative/examples/llama/verify.py,sha256=7xwKM_yzLCrmFsYj1UbsjW58ZG8Yic0xw1GFkdydrCU,2525
|
54
|
+
ai_edge_torch/generative/examples/llama/verify_3b.py,sha256=IijBWqLXINOfwayM-8EIpc7OcC6Nj5CnberStx-vDSk,2528
|
49
55
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
56
|
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKFP2UzCLC78tAkbwGlGhAArtG7Wa75NxJik,2185
|
51
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
52
|
-
ai_edge_torch/generative/examples/openelm/verify.py,sha256=
|
57
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=hxbpvk0fNswzbqZfGteflqKMmkH7yzeMuW6r29s_xnQ,7374
|
58
|
+
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
53
59
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
54
60
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
|
55
61
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
56
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
57
|
-
ai_edge_torch/generative/examples/phi/phi3.py,sha256=
|
58
|
-
ai_edge_torch/generative/examples/phi/verify.py,sha256=
|
59
|
-
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=
|
62
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=82SEKRwtKfT9VcNQaykGmemiov_XaXWLi4Zyw9Vtmj0,6075
|
63
|
+
ai_edge_torch/generative/examples/phi/phi3.py,sha256=Xh-l7TQdXYZJ9PViRVk2_y91Ec7Yntn0UpkuzRIG3T8,9231
|
64
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=YPFCdbnfmvq38fbpBNr0kHPfSZo4p3_6WkLJAW3pLPo,2177
|
65
|
+
ai_edge_torch/generative/examples/phi/verify_phi3.py,sha256=kVYaBVvddfQng0IyZGxyTJEzhiPO0G4VFJm2WOc2Q94,2360
|
60
66
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
61
67
|
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=zPrDTDeRVWFi9DS32uNi-RLpzOStFOk5MhNla4ixeew,2179
|
62
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
63
|
-
ai_edge_torch/generative/examples/smollm/verify.py,sha256=
|
68
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=dal8vnZjQd6vR7sc76-FYGDKUlVjOlfUALV-pwbXJGc,3264
|
69
|
+
ai_edge_torch/generative/examples/smollm/verify.py,sha256=HXYcCjDJMylVL3Pc9HU-UXqtpjtIU25o1YhPiX30aPU,2361
|
64
70
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
65
71
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
66
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
72
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lwWrKY1NpnbvHQRenpltVN65QlzjWmSScl5CLSipBkc,6110
|
67
73
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
68
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=
|
69
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
74
|
+
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=ClXNntmh0PF3s6U3C7SW3tyVrsSSrV2kyz-_RF4BcqA,15715
|
75
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=mBEAUYjV1qDJy9ZAsHtm9RGce0Mbzv0VoPZpdcQl1mk,33730
|
70
76
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
71
77
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
72
78
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -78,16 +84,16 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
|
|
78
84
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
|
79
85
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
80
86
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
|
81
|
-
ai_edge_torch/generative/examples/t5/t5.py,sha256=
|
87
|
+
ai_edge_torch/generative/examples/t5/t5.py,sha256=gFTmPi-xB8pcPRgoF3DJxvH_fT-KWTb8ii77P5UbKR0,21263
|
82
88
|
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
|
83
89
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
84
90
|
ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
|
85
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
86
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
|
91
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
|
92
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=ZpjSIiayjTEVwg5Q1vI9Iy5tq1YSF5zaVDF4HTp_Z2s,4353
|
87
93
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
88
94
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=ekxd8efjMgEvauUu3PidWOC-DszPHn5sqU753F7sJIM,2201
|
89
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
90
|
-
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=
|
95
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=aSNHOAar5yPnGAeKsv8zrqYhOq9RR_7hwqHUMBb2mkM,5930
|
96
|
+
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299fpQNXYAudBbZoDQp9934xcvg50,2426
|
91
97
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
92
98
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
93
99
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -96,7 +102,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHif
|
|
96
102
|
ai_edge_torch/generative/layers/builder.py,sha256=oE8DdqLA-oWkBC2zySSCh8JNAJg_hk8-W_UoMSrgDVk,5088
|
97
103
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
98
104
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
99
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
105
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=Fa0eFCMlyfdwd3cM1drhP9vlXRhIguDrglsHn4ax2_w,6948
|
100
106
|
ai_edge_torch/generative/layers/normalization.py,sha256=cpo88JUXbF9j3sJTU4JuwOap9ryGV05C1QkPij-YQwU,6999
|
101
107
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
102
108
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
@@ -115,7 +121,7 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
115
121
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
116
122
|
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
|
117
123
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=s-EVLOQGjIeVtgNI8Ggs37pkRdErAliT6NhrrFigPOE,5459
|
118
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
124
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=kCm-L3rWbPj25E_QEbkSLiaCk3y23SjKJs-MG-EwKug,8545
|
119
125
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
120
126
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
121
127
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
@@ -123,7 +129,8 @@ ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0b
|
|
123
129
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
124
130
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
125
131
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
126
|
-
ai_edge_torch/generative/utilities/
|
132
|
+
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
133
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=wQ4EtIED_a6FRsaOXeoQVZiHNx07esOYCQYbDVLgZ2o,9520
|
127
134
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
128
135
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
129
136
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -170,8 +177,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
170
177
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
171
178
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
172
179
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
173
|
-
ai_edge_torch_nightly-0.3.0.
|
174
|
-
ai_edge_torch_nightly-0.3.0.
|
175
|
-
ai_edge_torch_nightly-0.3.0.
|
176
|
-
ai_edge_torch_nightly-0.3.0.
|
177
|
-
ai_edge_torch_nightly-0.3.0.
|
180
|
+
ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
181
|
+
ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/METADATA,sha256=3HuAFZTfvmU787dVypwpmUvo4DdZSekGsqGimO-oPfM,1897
|
182
|
+
ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
183
|
+
ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
184
|
+
ai_edge_torch_nightly-0.3.0.dev20240928.dist-info/RECORD,,
|
File without changes
|
File without changes
|