ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +30 -0
- ai_edge_torch/convert/test/test_convert_composites.py +18 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +83 -49
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +7 -5
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +0 -260
- ai_edge_torch/generative/examples/t5/t5_attention.py +2 -2
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +161 -0
- ai_edge_torch/generative/layers/attention.py +27 -114
- ai_edge_torch/generative/layers/builder.py +4 -0
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +117 -0
- ai_edge_torch/generative/test/test_model_conversion.py +90 -80
- ai_edge_torch/generative/utilities/loader.py +56 -27
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/RECORD +18 -16
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240604.dist-info → ai_edge_torch_nightly-0.2.0.dev20240606.dist-info}/top_level.txt +0 -0
|
@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
|
|
|
33
33
|
"""Unit tests that check for model conversion and correctness."""
|
|
34
34
|
|
|
35
35
|
def test_toy_model_with_kv_cache(self):
|
|
36
|
-
self.skipTest("b/338288901")
|
|
37
36
|
config = toy_model_with_kv_cache.get_model_config()
|
|
38
37
|
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
39
38
|
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
|
|
@@ -42,19 +41,21 @@ class TestModelConversion(unittest.TestCase):
|
|
|
42
41
|
|
|
43
42
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
44
43
|
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
44
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
45
|
+
skip_output_check = True
|
|
46
|
+
if skip_output_check is False:
|
|
47
|
+
self.assertTrue(
|
|
48
|
+
model_coverage.compare_tflite_torch(
|
|
49
|
+
edge_model,
|
|
50
|
+
pytorch_model,
|
|
51
|
+
(idx, input_pos),
|
|
52
|
+
num_valid_inputs=1,
|
|
53
|
+
atol=1e-5,
|
|
54
|
+
rtol=1e-5,
|
|
55
|
+
)
|
|
56
|
+
)
|
|
55
57
|
|
|
56
58
|
def test_toy_model_with_kv_cache_with_hlfb(self):
|
|
57
|
-
self.skipTest("b/338288901")
|
|
58
59
|
config = toy_model_with_kv_cache.get_model_config()
|
|
59
60
|
config.enable_hlfb = True
|
|
60
61
|
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
|
|
@@ -64,16 +65,19 @@ class TestModelConversion(unittest.TestCase):
|
|
|
64
65
|
|
|
65
66
|
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
|
|
66
67
|
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
68
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
69
|
+
skip_output_check = True
|
|
70
|
+
if skip_output_check is False:
|
|
71
|
+
self.assertTrue(
|
|
72
|
+
model_coverage.compare_tflite_torch(
|
|
73
|
+
edge_model,
|
|
74
|
+
pytorch_model,
|
|
75
|
+
(idx, input_pos),
|
|
76
|
+
num_valid_inputs=1,
|
|
77
|
+
atol=1e-5,
|
|
78
|
+
rtol=1e-5,
|
|
79
|
+
)
|
|
80
|
+
)
|
|
77
81
|
|
|
78
82
|
def test_tiny_llama(self):
|
|
79
83
|
self.skipTest("b/338288901")
|
|
@@ -87,19 +91,21 @@ class TestModelConversion(unittest.TestCase):
|
|
|
87
91
|
|
|
88
92
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
89
93
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
94
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
95
|
+
skip_output_check = True
|
|
96
|
+
if skip_output_check is False:
|
|
97
|
+
self.assertTrue(
|
|
98
|
+
model_coverage.compare_tflite_torch(
|
|
99
|
+
edge_model,
|
|
100
|
+
pytorch_model,
|
|
101
|
+
(tokens, input_pos),
|
|
102
|
+
num_valid_inputs=1,
|
|
103
|
+
atol=1e-5,
|
|
104
|
+
rtol=1e-5,
|
|
105
|
+
)
|
|
106
|
+
)
|
|
100
107
|
|
|
101
108
|
def test_tiny_llama_multisig(self):
|
|
102
|
-
self.skipTest("b/338288901")
|
|
103
109
|
config = tiny_llama.get_fake_model_config_for_test()
|
|
104
110
|
pytorch_model = tiny_llama.TinyLLamma(config)
|
|
105
111
|
|
|
@@ -122,32 +128,30 @@ class TestModelConversion(unittest.TestCase):
|
|
|
122
128
|
.convert()
|
|
123
129
|
)
|
|
124
130
|
|
|
125
|
-
#
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
)
|
|
150
|
-
)
|
|
131
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
132
|
+
skip_output_check = True
|
|
133
|
+
if skip_output_check is False:
|
|
134
|
+
copied_model = copy.deepcopy(pytorch_model)
|
|
135
|
+
|
|
136
|
+
self.assertTrue(
|
|
137
|
+
model_coverage.compare_tflite_torch(
|
|
138
|
+
edge_model,
|
|
139
|
+
pytorch_model,
|
|
140
|
+
(prefill_tokens, prefill_input_pos),
|
|
141
|
+
signature_name="prefill",
|
|
142
|
+
num_valid_inputs=1,
|
|
143
|
+
)
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
self.assertTrue(
|
|
147
|
+
model_coverage.compare_tflite_torch(
|
|
148
|
+
edge_model,
|
|
149
|
+
copied_model,
|
|
150
|
+
(decode_token, decode_input_pos),
|
|
151
|
+
signature_name="decode",
|
|
152
|
+
num_valid_inputs=1,
|
|
153
|
+
)
|
|
154
|
+
)
|
|
151
155
|
|
|
152
156
|
def test_gemma(self):
|
|
153
157
|
self.skipTest("b/338288901")
|
|
@@ -161,17 +165,20 @@ class TestModelConversion(unittest.TestCase):
|
|
|
161
165
|
|
|
162
166
|
edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
|
|
163
167
|
|
|
164
|
-
# TODO(
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
168
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
169
|
+
skip_output_check = True
|
|
170
|
+
if skip_output_check is False:
|
|
171
|
+
# TODO(talumbau, haoliang): debug numerical diff.
|
|
172
|
+
self.assertTrue(
|
|
173
|
+
model_coverage.compare_tflite_torch(
|
|
174
|
+
edge_model,
|
|
175
|
+
model,
|
|
176
|
+
(tokens, input_pos),
|
|
177
|
+
num_valid_inputs=1,
|
|
178
|
+
atol=1e-2,
|
|
179
|
+
rtol=1e-5,
|
|
180
|
+
)
|
|
181
|
+
)
|
|
175
182
|
|
|
176
183
|
def test_phi2(self):
|
|
177
184
|
self.skipTest("b/338288901")
|
|
@@ -185,16 +192,19 @@ class TestModelConversion(unittest.TestCase):
|
|
|
185
192
|
|
|
186
193
|
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
|
|
187
194
|
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
195
|
+
# TODO(b/338288901): re-enable test to check output tensors.
|
|
196
|
+
skip_output_check = True
|
|
197
|
+
if skip_output_check is False:
|
|
198
|
+
self.assertTrue(
|
|
199
|
+
model_coverage.compare_tflite_torch(
|
|
200
|
+
edge_model,
|
|
201
|
+
pytorch_model,
|
|
202
|
+
(tokens, input_pos),
|
|
203
|
+
num_valid_inputs=1,
|
|
204
|
+
atol=1e-5,
|
|
205
|
+
rtol=1e-5,
|
|
206
|
+
)
|
|
207
|
+
)
|
|
198
208
|
|
|
199
209
|
|
|
200
210
|
if __name__ == "__main__":
|
|
@@ -69,10 +69,16 @@ def load_pytorch_statedict(full_path: str):
|
|
|
69
69
|
Raises:
|
|
70
70
|
ValueError: If no tensors are loaded from the provided directory or file.
|
|
71
71
|
"""
|
|
72
|
-
pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
|
|
73
72
|
files = []
|
|
74
|
-
|
|
75
|
-
|
|
73
|
+
patterns = []
|
|
74
|
+
if os.path.isdir(full_path):
|
|
75
|
+
patterns.append(os.path.join(full_path, "*.bin"))
|
|
76
|
+
patterns.append(os.path.join(full_path, "*.pt"))
|
|
77
|
+
else:
|
|
78
|
+
patterns.append(full_path)
|
|
79
|
+
for pattern in patterns:
|
|
80
|
+
for file in glob.glob(pattern):
|
|
81
|
+
files.append(file)
|
|
76
82
|
|
|
77
83
|
tensors = {}
|
|
78
84
|
for file in files:
|
|
@@ -93,18 +99,20 @@ class ModelLoader:
|
|
|
93
99
|
|
|
94
100
|
@dataclass
|
|
95
101
|
class TensorNames:
|
|
96
|
-
attn_query_proj: str
|
|
97
|
-
attn_key_proj: str
|
|
98
|
-
attn_value_proj: str
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
102
|
+
attn_query_proj: str = None
|
|
103
|
+
attn_key_proj: str = None
|
|
104
|
+
attn_value_proj: str = None
|
|
105
|
+
attn_fused_qkv_proj: str = None
|
|
106
|
+
attn_output_proj: str = None
|
|
107
|
+
|
|
108
|
+
ff_up_proj: str = None
|
|
109
|
+
ff_down_proj: str = None
|
|
103
110
|
ff_gate_proj: str = None
|
|
104
111
|
|
|
105
112
|
pre_attn_norm: str = None
|
|
106
113
|
pre_ff_norm: str = None
|
|
107
114
|
embedding: str = None
|
|
115
|
+
embedding_position: str = None
|
|
108
116
|
final_norm: str = None
|
|
109
117
|
lm_head: str = None
|
|
110
118
|
|
|
@@ -129,6 +137,10 @@ class ModelLoader:
|
|
|
129
137
|
strict (bool, optional): Whether the converted keys are strictly
|
|
130
138
|
matched. Defaults to True.
|
|
131
139
|
|
|
140
|
+
Returns:
|
|
141
|
+
missing_keys (List[str]): a list of str containing the missing keys
|
|
142
|
+
unexpected_keys (List[str]): a list of str containing the unexpected keys
|
|
143
|
+
|
|
132
144
|
Raises:
|
|
133
145
|
ValueError: If conversion results in unmapped tensors and strict mode is
|
|
134
146
|
enabled.
|
|
@@ -139,6 +151,10 @@ class ModelLoader:
|
|
|
139
151
|
converted_state["tok_embedding.weight"] = state.pop(
|
|
140
152
|
f"{self._names.embedding}.weight"
|
|
141
153
|
)
|
|
154
|
+
if self._names.embedding_position is not None:
|
|
155
|
+
converted_state["tok_embedding_position"] = state.pop(
|
|
156
|
+
f"{self._names.embedding_position}"
|
|
157
|
+
)
|
|
142
158
|
if self._names.lm_head is not None:
|
|
143
159
|
converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
|
|
144
160
|
if model.config.lm_head_use_bias:
|
|
@@ -158,7 +174,7 @@ class ModelLoader:
|
|
|
158
174
|
raise ValueError(
|
|
159
175
|
f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
|
|
160
176
|
)
|
|
161
|
-
model.load_state_dict(converted_state, strict=strict)
|
|
177
|
+
return model.load_state_dict(converted_state, strict=strict)
|
|
162
178
|
|
|
163
179
|
def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
|
|
164
180
|
"""A best effort method for finding appropriate state loader.
|
|
@@ -172,13 +188,15 @@ class ModelLoader:
|
|
|
172
188
|
if os.path.isdir(self._file_name):
|
|
173
189
|
if glob.glob(os.path.join(self._file_name, "*.safetensors")):
|
|
174
190
|
return load_safetensors
|
|
175
|
-
if glob.glob(os.path.join(self._file_name, "*.bin"))
|
|
191
|
+
if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
|
|
192
|
+
os.path.join(self._file_name, "*.pt")
|
|
193
|
+
):
|
|
176
194
|
return load_pytorch_statedict
|
|
177
195
|
|
|
178
196
|
if self._file_name.endswith(".safetensors"):
|
|
179
197
|
return load_safetensors
|
|
180
198
|
|
|
181
|
-
if self._file_name.endswith(".bin"):
|
|
199
|
+
if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
|
|
182
200
|
return load_pytorch_statedict
|
|
183
201
|
|
|
184
202
|
raise ValueError(f"File format not supported.")
|
|
@@ -225,22 +243,33 @@ class ModelLoader:
|
|
|
225
243
|
converted_state: Dict[str, torch.Tensor],
|
|
226
244
|
):
|
|
227
245
|
prefix = f"transformer_blocks.{idx}"
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
|
|
246
|
+
if self._names.attn_fused_qkv_proj:
|
|
247
|
+
fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
|
|
248
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
|
|
249
|
+
f"{fused_qkv_name}.weight"
|
|
250
|
+
)
|
|
251
|
+
else:
|
|
252
|
+
q_name = self._names.attn_query_proj.format(idx)
|
|
253
|
+
k_name = self._names.attn_key_proj.format(idx)
|
|
254
|
+
v_name = self._names.attn_value_proj.format(idx)
|
|
255
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
|
|
239
256
|
config,
|
|
240
|
-
state.pop(f"{q_name}.
|
|
241
|
-
state.pop(f"{k_name}.
|
|
242
|
-
state.pop(f"{v_name}.
|
|
257
|
+
state.pop(f"{q_name}.weight"),
|
|
258
|
+
state.pop(f"{k_name}.weight"),
|
|
259
|
+
state.pop(f"{v_name}.weight"),
|
|
243
260
|
)
|
|
261
|
+
if config.attn_config.qkv_use_bias:
|
|
262
|
+
if self._names.attn_fused_qkv_proj:
|
|
263
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
|
|
264
|
+
f"{fused_qkv_name}.bias"
|
|
265
|
+
)
|
|
266
|
+
else:
|
|
267
|
+
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
|
|
268
|
+
config,
|
|
269
|
+
state.pop(f"{q_name}.bias"),
|
|
270
|
+
state.pop(f"{k_name}.bias"),
|
|
271
|
+
state.pop(f"{v_name}.bias"),
|
|
272
|
+
)
|
|
244
273
|
|
|
245
274
|
o_name = self._names.attn_output_proj.format(idx)
|
|
246
275
|
converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.2.0.dev20240604
|
|
3
|
+
Version: 0.2.0.dev20240606
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
|
|
|
6
6
|
ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
|
|
7
7
|
ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
|
|
8
8
|
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
|
|
9
|
-
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=
|
|
9
|
+
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
|
|
10
10
|
ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
|
|
11
11
|
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
|
|
12
12
|
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
|
|
|
22
22
|
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
|
|
23
23
|
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
24
24
|
ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
|
|
25
|
-
ai_edge_torch/convert/test/test_convert_composites.py,sha256=
|
|
25
|
+
ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
|
|
26
26
|
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
|
|
27
27
|
ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
|
|
28
28
|
ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
|
|
@@ -40,10 +40,10 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlY
|
|
|
40
40
|
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
|
|
41
41
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
42
42
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
|
|
43
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
|
44
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
|
|
43
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
|
|
44
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=MI73RjOeD4Kh7AL0j5_QXiZq-rl_qCdibSE6eCQCyeY,3804
|
|
45
45
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
|
|
46
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
|
46
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
|
|
47
47
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
|
|
48
48
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
|
|
49
49
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
|
|
@@ -56,22 +56,24 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
|
|
|
56
56
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
57
57
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
|
|
58
58
|
ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
|
|
59
|
-
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
|
|
59
|
+
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
|
|
60
60
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
61
61
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
|
|
62
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
|
|
62
63
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
|
|
63
64
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
64
65
|
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
|
|
65
66
|
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
|
|
66
67
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
67
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
|
68
|
+
ai_edge_torch/generative/layers/attention.py,sha256=zNIBXxCOA5Mz_F_dfBbKpIovhtcB6q5a-i8oAxls1d0,7071
|
|
68
69
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
|
|
69
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
|
70
|
+
ai_edge_torch/generative/layers/builder.py,sha256=WLTeDId9t3Xwt0h1zxzqoYyFvfrNzPKLskcl39q8Aqw,3403
|
|
70
71
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
|
|
71
72
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
|
|
72
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
|
73
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=2zT9nyoyuuyk5ziiww0VSJ6_JO7pDf7uOYbO9O3OQc4,4249
|
|
73
74
|
ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
|
|
74
75
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
|
|
76
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
|
|
75
77
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
76
78
|
ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
|
|
77
79
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
|
|
@@ -81,10 +83,10 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2w
|
|
|
81
83
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
|
|
82
84
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
83
85
|
ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
|
|
84
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
|
86
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
|
|
85
87
|
ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
|
|
86
88
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
87
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
|
89
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
|
|
88
90
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
|
|
89
91
|
ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
|
|
90
92
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
|
|
@@ -100,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
|
|
|
100
102
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
101
103
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
102
104
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
|
|
103
|
-
ai_edge_torch_nightly-0.2.0.
|
|
104
|
-
ai_edge_torch_nightly-0.2.0.
|
|
105
|
-
ai_edge_torch_nightly-0.2.0.
|
|
106
|
-
ai_edge_torch_nightly-0.2.0.
|
|
107
|
-
ai_edge_torch_nightly-0.2.0.
|
|
105
|
+
ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
106
|
+
ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/METADATA,sha256=2yFjQFvs93hoppwWWSJi-B9HC2n-h1s2rzjsyGXU1zI,1748
|
|
107
|
+
ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
|
|
108
|
+
ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
109
|
+
ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|