ai-edge-torch-nightly 0.2.0.dev20240806__py3-none-any.whl → 0.3.0.dev20240809__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/__init__.py +5 -5
- ai_edge_torch/{convert → _convert}/conversion.py +40 -50
- ai_edge_torch/_convert/conversion_utils.py +64 -0
- ai_edge_torch/{convert → _convert}/converter.py +83 -43
- ai_edge_torch/{convert → _convert}/fx_passes/__init__.py +9 -9
- ai_edge_torch/{convert → _convert}/fx_passes/build_aten_composite_pass.py +51 -26
- ai_edge_torch/{convert → _convert}/fx_passes/build_interpolate_composite_pass.py +11 -8
- ai_edge_torch/{convert → _convert}/fx_passes/canonicalize_pass.py +3 -4
- ai_edge_torch/{convert → _convert}/fx_passes/inject_mlir_debuginfo_pass.py +2 -2
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py +16 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_check.py +7 -5
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_mark.py +2 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py +1 -0
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +14 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +5 -6
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +17 -14
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +3 -2
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/pass_body.py +15 -17
- ai_edge_torch/{convert → _convert}/fx_passes/optimize_layout_transposes_pass/utils.py +2 -0
- ai_edge_torch/_convert/signature.py +100 -0
- ai_edge_torch/{convert → _convert}/test/test_convert.py +50 -52
- ai_edge_torch/{convert → _convert}/test/test_convert_composites.py +16 -12
- ai_edge_torch/{convert → _convert}/test/test_convert_multisig.py +6 -4
- ai_edge_torch/{convert → _convert}/test/test_to_channel_last_io.py +5 -4
- ai_edge_torch/{convert → _convert}/to_channel_last_io.py +4 -1
- ai_edge_torch/config.py +24 -0
- ai_edge_torch/conftest.py +20 -0
- ai_edge_torch/debug/culprit.py +22 -22
- ai_edge_torch/debug/test/test_culprit.py +4 -3
- ai_edge_torch/debug/test/test_search_model.py +5 -5
- ai_edge_torch/debug/utils.py +11 -2
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +3 -3
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/experimental/phi/phi2.py +4 -1
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +4 -5
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/gemma/gemma.py +4 -1
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/phi2/phi2.py +4 -1
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +2 -0
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py +3 -2
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +57 -20
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +20 -9
- ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py +1 -0
- ai_edge_torch/generative/examples/t5/t5.py +2 -2
- ai_edge_torch/generative/examples/t5/t5_attention.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +4 -1
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +6 -5
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +7 -7
- ai_edge_torch/generative/examples/tiny_llama/__init__.py +14 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +5 -5
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +4 -1
- ai_edge_torch/generative/fx_passes/__init__.py +2 -2
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +4 -3
- ai_edge_torch/generative/layers/attention.py +35 -26
- ai_edge_torch/generative/layers/attention_utils.py +23 -12
- ai_edge_torch/generative/layers/builder.py +0 -1
- ai_edge_torch/generative/layers/feed_forward.py +6 -10
- ai_edge_torch/generative/layers/kv_cache.py +0 -1
- ai_edge_torch/generative/layers/model_config.py +2 -5
- ai_edge_torch/generative/layers/normalization.py +5 -7
- ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -3
- ai_edge_torch/generative/layers/unet/blocks_2d.py +33 -26
- ai_edge_torch/generative/layers/unet/model_config.py +14 -15
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +14 -0
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +0 -2
- ai_edge_torch/generative/quantize/quant_recipe.py +8 -6
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +2 -1
- ai_edge_torch/generative/test/test_experimental_ekv.py +6 -7
- ai_edge_torch/generative/test/{loader_test.py → test_loader.py} +4 -3
- ai_edge_torch/generative/test/test_model_conversion.py +24 -25
- ai_edge_torch/generative/test/test_quantize.py +10 -5
- ai_edge_torch/generative/utilities/loader.py +12 -12
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +69 -24
- ai_edge_torch/generative/utilities/t5_loader.py +12 -13
- ai_edge_torch/hlfb/__init__.py +1 -1
- ai_edge_torch/hlfb/mark_pattern/__init__.py +9 -6
- ai_edge_torch/hlfb/mark_pattern/passes.py +23 -3
- ai_edge_torch/hlfb/mark_pattern/pattern.py +23 -23
- ai_edge_torch/hlfb/test/test_mark_pattern.py +13 -12
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +8 -6
- ai_edge_torch/{convert/fx_passes/optimize_layout_transposes_pass → lowertools}/__init__.py +1 -1
- ai_edge_torch/lowertools/_shim.py +80 -0
- ai_edge_torch/lowertools/common_utils.py +89 -0
- ai_edge_torch/lowertools/odml_torch_utils.py +211 -0
- ai_edge_torch/lowertools/torch_xla_utils.py +273 -0
- ai_edge_torch/model.py +14 -9
- ai_edge_torch/quantize/pt2e_quantizer.py +22 -9
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +13 -12
- ai_edge_torch/quantize/quant_config.py +7 -7
- ai_edge_torch/testing/model_coverage/model_coverage.py +19 -10
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/METADATA +1 -1
- ai_edge_torch_nightly-0.3.0.dev20240809.dist-info/RECORD +141 -0
- ai_edge_torch/convert/conversion_utils.py +0 -439
- ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD +0 -133
- /ai_edge_torch/{convert → _convert}/__init__.py +0 -0
- /ai_edge_torch/{convert → _convert}/fx_passes/_pass_base.py +0 -0
- /ai_edge_torch/{convert → _convert}/test/__init__.py +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.2.0.dev20240806.dist-info → ai_edge_torch_nightly-0.3.0.dev20240809.dist-info}/top_level.txt +0 -0
|
@@ -1,439 +0,0 @@
|
|
|
1
|
-
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
import collections
|
|
17
|
-
import copy
|
|
18
|
-
from dataclasses import dataclass
|
|
19
|
-
import gc
|
|
20
|
-
import itertools
|
|
21
|
-
import logging
|
|
22
|
-
import tempfile
|
|
23
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
24
|
-
|
|
25
|
-
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
|
|
26
|
-
from ai_edge_torch.quantize import quant_config as qcfg
|
|
27
|
-
import torch
|
|
28
|
-
import torch.utils._pytree as pytree
|
|
29
|
-
from torch_xla import stablehlo
|
|
30
|
-
|
|
31
|
-
try:
|
|
32
|
-
import tensorflow as tf
|
|
33
|
-
|
|
34
|
-
from tensorflow.compiler.tf2xla.python import xla as tfxla
|
|
35
|
-
|
|
36
|
-
from tensorflow.lite.python import conversion_metadata_schema_py_generated as conversion_metadata_fb # isort:skip
|
|
37
|
-
except ImportError:
|
|
38
|
-
logging.error(
|
|
39
|
-
"This module needs tensorflow with xla support.\n"
|
|
40
|
-
"Please install tensorflow with `pip install tf-nightly`.\n"
|
|
41
|
-
)
|
|
42
|
-
raise
|
|
43
|
-
|
|
44
|
-
DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
@dataclass
|
|
48
|
-
class Signature:
|
|
49
|
-
name: str
|
|
50
|
-
module: torch.nn.Module
|
|
51
|
-
sample_args: tuple[torch.Tensor]
|
|
52
|
-
sample_kwargs: dict[str, torch.Tensor]
|
|
53
|
-
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
|
|
54
|
-
|
|
55
|
-
@property
|
|
56
|
-
def _normalized_sample_args_kwargs(self):
|
|
57
|
-
args, kwargs = self.sample_args, self.sample_kwargs
|
|
58
|
-
if args is not None:
|
|
59
|
-
if not isinstance(args, tuple):
|
|
60
|
-
# TODO(b/352584188): Check value types
|
|
61
|
-
raise ValueError("sample_args must be a tuple of torch tensors.")
|
|
62
|
-
if kwargs is not None:
|
|
63
|
-
if not isinstance(kwargs, dict) or not all(
|
|
64
|
-
isinstance(key, str) for key in kwargs.keys()
|
|
65
|
-
):
|
|
66
|
-
# TODO(b/352584188): Check value types
|
|
67
|
-
raise ValueError("sample_kwargs must be a dict of string to tensor.")
|
|
68
|
-
|
|
69
|
-
args = args if args is not None else tuple()
|
|
70
|
-
kwargs = kwargs if kwargs is not None else {}
|
|
71
|
-
return args, kwargs
|
|
72
|
-
|
|
73
|
-
@property
|
|
74
|
-
def flat_arg_names(self) -> list[str]:
|
|
75
|
-
spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
|
|
76
|
-
args_spec, kwargs_spec = spec.children_specs
|
|
77
|
-
|
|
78
|
-
names = []
|
|
79
|
-
for i in range(args_spec.num_leaves):
|
|
80
|
-
names.append(f"args_{i}")
|
|
81
|
-
|
|
82
|
-
kwargs_names = self._flat_kwarg_names(
|
|
83
|
-
kwargs_spec.children_specs, kwargs_spec.context
|
|
84
|
-
)
|
|
85
|
-
names.extend(kwargs_names)
|
|
86
|
-
return names
|
|
87
|
-
|
|
88
|
-
def _flat_kwarg_names(self, specs, context) -> List[str]:
|
|
89
|
-
flat_names = []
|
|
90
|
-
if context is None:
|
|
91
|
-
for i, spec in enumerate(specs):
|
|
92
|
-
if spec.children_specs:
|
|
93
|
-
flat_names.extend([
|
|
94
|
-
f"{i}_{name}"
|
|
95
|
-
for name in self._flat_kwarg_names(
|
|
96
|
-
spec.children_specs, spec.context
|
|
97
|
-
)
|
|
98
|
-
])
|
|
99
|
-
else:
|
|
100
|
-
flat_names.append(f"{i}")
|
|
101
|
-
else:
|
|
102
|
-
flat_ctx = self._flatten_list(context)
|
|
103
|
-
for prefix, spec in zip(flat_ctx, specs):
|
|
104
|
-
leaf_flat_names = self._flat_kwarg_names(
|
|
105
|
-
spec.children_specs, spec.context
|
|
106
|
-
)
|
|
107
|
-
if leaf_flat_names:
|
|
108
|
-
flat_names.extend([f"{prefix}_{name}" for name in leaf_flat_names])
|
|
109
|
-
else:
|
|
110
|
-
flat_names.append(prefix)
|
|
111
|
-
|
|
112
|
-
return flat_names
|
|
113
|
-
|
|
114
|
-
def _flatten_list(self, l: List) -> List:
|
|
115
|
-
flattened = []
|
|
116
|
-
for item in l:
|
|
117
|
-
if isinstance(item, list):
|
|
118
|
-
flattened.extend(self._flatten_list(item))
|
|
119
|
-
else:
|
|
120
|
-
flattened.append(item)
|
|
121
|
-
return flattened
|
|
122
|
-
|
|
123
|
-
@property
|
|
124
|
-
def flat_args(self) -> tuple[Any]:
|
|
125
|
-
args, kwargs = self._normalized_sample_args_kwargs
|
|
126
|
-
return tuple([*args, *kwargs.values()])
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def exported_program_to_stablehlo_bundle(
|
|
130
|
-
exported_program: torch.export.ExportedProgram,
|
|
131
|
-
sample_args: tuple[torch.Tensor],
|
|
132
|
-
) -> stablehlo.StableHLOModelBundle:
|
|
133
|
-
# Setting export_weights to False here so that pytorch/xla avoids copying the weights
|
|
134
|
-
# to a numpy array which would lead to memory bloat. This means that the state_dict
|
|
135
|
-
# in the returned bundle is going to be empty.
|
|
136
|
-
return stablehlo.exported_program_to_stablehlo(
|
|
137
|
-
exported_program,
|
|
138
|
-
stablehlo.StableHLOExportOptions(
|
|
139
|
-
override_tracing_arguments=sample_args, export_weights=False
|
|
140
|
-
),
|
|
141
|
-
)._bundle
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def _torch_to_tf_tensor(torch_tensor: torch.Tensor):
|
|
145
|
-
if not torch_tensor.is_contiguous():
|
|
146
|
-
torch_tensor = torch_tensor.contiguous()
|
|
147
|
-
|
|
148
|
-
try:
|
|
149
|
-
dlpack_capsule = torch.utils.dlpack.to_dlpack(torch_tensor)
|
|
150
|
-
tf_tensor = tf.experimental.dlpack.from_dlpack(dlpack_capsule)
|
|
151
|
-
except Exception:
|
|
152
|
-
logging.info(
|
|
153
|
-
"Can not use dlpack to convert torch tensors. Falling back to numpy."
|
|
154
|
-
)
|
|
155
|
-
nparray = torch_tensor.cpu().detach().numpy()
|
|
156
|
-
tf_tensor = tf.convert_to_tensor(nparray)
|
|
157
|
-
|
|
158
|
-
return tf_tensor
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
def _get_states(
|
|
162
|
-
exported_programs: list[torch.export.ExportedProgram],
|
|
163
|
-
signatures: list[Signature],
|
|
164
|
-
):
|
|
165
|
-
for exported_program, signature in zip(exported_programs, signatures):
|
|
166
|
-
args, _ = exported_program.example_inputs
|
|
167
|
-
# Calling this to get **all** the state including model buffers.
|
|
168
|
-
_flat_input_args = exported_program._graph_module_flat_inputs(args, {})
|
|
169
|
-
for tensor, input_spec in zip(
|
|
170
|
-
_flat_input_args, exported_program.graph_signature.input_specs
|
|
171
|
-
):
|
|
172
|
-
# Only interested in Tensors that are part of the state (and not user input).
|
|
173
|
-
if (
|
|
174
|
-
not isinstance(tensor, torch.Tensor)
|
|
175
|
-
or input_spec.kind
|
|
176
|
-
== torch.export.graph_signature.InputKind.USER_INPUT
|
|
177
|
-
):
|
|
178
|
-
continue
|
|
179
|
-
yield signature, tensor, input_spec
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
def _tensor_unique_id(tensor: torch.Tensor):
|
|
183
|
-
return (
|
|
184
|
-
str(tensor.device),
|
|
185
|
-
tensor.shape,
|
|
186
|
-
tensor.stride(),
|
|
187
|
-
tensor.untyped_storage().data_ptr(),
|
|
188
|
-
)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
def _gather_state_dict(
|
|
192
|
-
exported_programs: list[torch.export.ExportedProgram],
|
|
193
|
-
signatures: list[Signature],
|
|
194
|
-
):
|
|
195
|
-
deduped_tensor_map = {}
|
|
196
|
-
|
|
197
|
-
for _, tensor, _ in _get_states(exported_programs, signatures):
|
|
198
|
-
unique_id = _tensor_unique_id(tensor)
|
|
199
|
-
deduped_tensor_map[unique_id] = _torch_to_tf_tensor(tensor)
|
|
200
|
-
|
|
201
|
-
state_dict = {}
|
|
202
|
-
for signature, tensor, input_spec in _get_states(
|
|
203
|
-
exported_programs, signatures
|
|
204
|
-
):
|
|
205
|
-
unique_id = _tensor_unique_id(tensor)
|
|
206
|
-
state_dict[signature.name + "_" + input_spec.target] = deduped_tensor_map[
|
|
207
|
-
unique_id
|
|
208
|
-
]
|
|
209
|
-
|
|
210
|
-
return state_dict
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
def merge_stablehlo_bundles(
|
|
214
|
-
bundles: list[stablehlo.StableHLOModelBundle],
|
|
215
|
-
signatures: list[Signature],
|
|
216
|
-
exported_programs: list[torch.export.ExportedProgram],
|
|
217
|
-
) -> stablehlo.StableHLOGraphModule:
|
|
218
|
-
state_dict = _gather_state_dict(exported_programs, signatures)
|
|
219
|
-
|
|
220
|
-
new_bundle = stablehlo.StableHLOModelBundle(
|
|
221
|
-
state_dict=state_dict, additional_constants=[], stablehlo_funcs=[]
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
for bundle, signature in zip(bundles, signatures):
|
|
225
|
-
const_offset = len(new_bundle.additional_constants)
|
|
226
|
-
for func in bundle.stablehlo_funcs:
|
|
227
|
-
func.meta.name = signature.name + "_" + func.meta.name
|
|
228
|
-
for loc in func.meta.input_locations:
|
|
229
|
-
if loc.type_ == stablehlo.VariableType.CONSTANT:
|
|
230
|
-
loc.position += const_offset
|
|
231
|
-
elif loc.type_ == stablehlo.VariableType.PARAMETER:
|
|
232
|
-
loc.name = signature.name + "_" + loc.name
|
|
233
|
-
new_bundle.stablehlo_funcs.append(func)
|
|
234
|
-
new_bundle.additional_constants.extend(bundle.additional_constants)
|
|
235
|
-
return stablehlo.StableHLOGraphModule(new_bundle)
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
def _get_shape_with_dynamic(signature: stablehlo.VariableSignature):
|
|
239
|
-
shape = copy.copy(signature.shape)
|
|
240
|
-
for i in signature.dynamic_dims:
|
|
241
|
-
shape[i] = None
|
|
242
|
-
return shape
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
def _wrap_as_tf_func(
|
|
246
|
-
func: stablehlo.StableHLOFunc, bundle: stablehlo.StableHLOModelBundle
|
|
247
|
-
):
|
|
248
|
-
def inner(*args):
|
|
249
|
-
type_info = [sig.dtype for sig in func.meta.output_signature]
|
|
250
|
-
shape_info = [
|
|
251
|
-
_get_shape_with_dynamic(sig) for sig in func.meta.output_signature
|
|
252
|
-
]
|
|
253
|
-
call_args = stablehlo._extract_call_parameters(args, func.meta, bundle)
|
|
254
|
-
return tfxla.call_module(
|
|
255
|
-
tuple(call_args),
|
|
256
|
-
version=5,
|
|
257
|
-
Tout=type_info,
|
|
258
|
-
Sout=shape_info,
|
|
259
|
-
function_list=[],
|
|
260
|
-
module=func.bytecode,
|
|
261
|
-
)
|
|
262
|
-
|
|
263
|
-
return inner
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
def _make_tf_function(
|
|
267
|
-
shlo_graph_module: stablehlo.StableHLOGraphModule,
|
|
268
|
-
bundle: stablehlo.StableHLOModelBundle = None,
|
|
269
|
-
):
|
|
270
|
-
bundle = shlo_graph_module._bundle if bundle is None else bundle
|
|
271
|
-
return [
|
|
272
|
-
_wrap_as_tf_func(func, bundle)
|
|
273
|
-
for func in shlo_graph_module._bundle.stablehlo_funcs
|
|
274
|
-
]
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
def _make_tf_signature(
|
|
278
|
-
meta: stablehlo.StableHLOFunctionMeta,
|
|
279
|
-
signature: Signature,
|
|
280
|
-
) -> list[tf.TensorSpec]:
|
|
281
|
-
input_names = signature.flat_arg_names
|
|
282
|
-
input_pos_to_spec = {
|
|
283
|
-
loc.position: spec
|
|
284
|
-
for loc, spec in itertools.chain(
|
|
285
|
-
zip(meta.input_locations, meta.input_signature), meta.unused_inputs
|
|
286
|
-
)
|
|
287
|
-
if loc.type_ == stablehlo.VariableType.INPUT_ARG
|
|
288
|
-
}
|
|
289
|
-
assert len(input_pos_to_spec) == len(input_names)
|
|
290
|
-
|
|
291
|
-
primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
|
|
292
|
-
ret: list[tf.TensorSpec] = []
|
|
293
|
-
for i, name in enumerate(input_names):
|
|
294
|
-
spec = input_pos_to_spec[i]
|
|
295
|
-
shape = _get_shape_with_dynamic(spec)
|
|
296
|
-
ret.append(
|
|
297
|
-
tf.TensorSpec(
|
|
298
|
-
shape=shape,
|
|
299
|
-
dtype=primitive_type_to_tf_type[spec.dtype]
|
|
300
|
-
if spec.dtype in primitive_type_to_tf_type
|
|
301
|
-
else spec.dtype,
|
|
302
|
-
name=name,
|
|
303
|
-
)
|
|
304
|
-
)
|
|
305
|
-
return ret
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
def _apply_tfl_backdoor_flags(
|
|
309
|
-
converter: tf.lite.TFLiteConverter, tfl_converter_flags: dict
|
|
310
|
-
):
|
|
311
|
-
def _set_converter_flag(path: list):
|
|
312
|
-
if len(path) < 2:
|
|
313
|
-
raise ValueError("Expecting at least two values in the path.")
|
|
314
|
-
|
|
315
|
-
target_obj = converter
|
|
316
|
-
for idx in range(len(path) - 2):
|
|
317
|
-
target_obj = getattr(target_obj, path[idx])
|
|
318
|
-
|
|
319
|
-
setattr(target_obj, path[-2], path[-1])
|
|
320
|
-
|
|
321
|
-
def _iterate_dict_tree(flags_dict: dict, path: list):
|
|
322
|
-
for key, value in flags_dict.items():
|
|
323
|
-
path.append(key)
|
|
324
|
-
if isinstance(value, dict):
|
|
325
|
-
_iterate_dict_tree(value, path)
|
|
326
|
-
else:
|
|
327
|
-
path.append(value)
|
|
328
|
-
_set_converter_flag(path)
|
|
329
|
-
path.pop()
|
|
330
|
-
path.pop()
|
|
331
|
-
|
|
332
|
-
_iterate_dict_tree(tfl_converter_flags, [])
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
def _set_tfl_converter_quant_flags(
|
|
336
|
-
converter: tf.lite.TFLiteConverter, quant_config: qcfg.QuantConfig
|
|
337
|
-
):
|
|
338
|
-
if quant_config is not None:
|
|
339
|
-
quantizer_mode = quant_config._quantizer_mode
|
|
340
|
-
if quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_DYNAMIC:
|
|
341
|
-
converter._experimental_qdq_conversion_mode = "DYNAMIC"
|
|
342
|
-
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
|
|
343
|
-
converter._experimental_qdq_conversion_mode = "STATIC"
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
def convert_stablehlo_to_tflite(
|
|
347
|
-
shlo_graph_module: stablehlo.StableHLOGraphModule,
|
|
348
|
-
signatures: list[Signature],
|
|
349
|
-
*,
|
|
350
|
-
quant_config: Optional[qcfg.QuantConfig] = None,
|
|
351
|
-
_tfl_converter_flags: dict = {},
|
|
352
|
-
) -> None:
|
|
353
|
-
"""Converts a StableHLOGraphModule to a tflite model.
|
|
354
|
-
Args:
|
|
355
|
-
shlo_graph_module - model to export and save
|
|
356
|
-
signatures: List of signatures from which names of the signatures is extracted.
|
|
357
|
-
quant_config: User-defined quantization method and scheme of the model.
|
|
358
|
-
_tfl_converter_flags: A nested dictionary allowing setting flags for the underlying tflite converter.
|
|
359
|
-
"""
|
|
360
|
-
|
|
361
|
-
bundle = shlo_graph_module._bundle
|
|
362
|
-
tf_module = tf.Module()
|
|
363
|
-
bundle.state_dict = {
|
|
364
|
-
k: tf.Variable(v, trainable=False) for k, v in bundle.state_dict.items()
|
|
365
|
-
}
|
|
366
|
-
bundle.additional_constants = [
|
|
367
|
-
tf.Variable(v, trainable=False) for v in bundle.additional_constants
|
|
368
|
-
]
|
|
369
|
-
tf_signatures: list[list[tf.TensorSpec]] = list(
|
|
370
|
-
_make_tf_signature(func.meta, sig)
|
|
371
|
-
for func, sig in zip(bundle.stablehlo_funcs, signatures)
|
|
372
|
-
)
|
|
373
|
-
|
|
374
|
-
tf_functions = _make_tf_function(shlo_graph_module, bundle)
|
|
375
|
-
|
|
376
|
-
tf_module.f = []
|
|
377
|
-
for tf_sig, func in zip(tf_signatures, tf_functions):
|
|
378
|
-
tf_module.f.append(
|
|
379
|
-
tf.function(
|
|
380
|
-
func,
|
|
381
|
-
input_signature=tf_sig,
|
|
382
|
-
)
|
|
383
|
-
)
|
|
384
|
-
|
|
385
|
-
tf_module._variables = (
|
|
386
|
-
list(bundle.state_dict.values()) + bundle.additional_constants
|
|
387
|
-
)
|
|
388
|
-
del bundle
|
|
389
|
-
gc.collect()
|
|
390
|
-
|
|
391
|
-
tf_concrete_funcs = [
|
|
392
|
-
func.get_concrete_function(*tf_sig)
|
|
393
|
-
for func, tf_sig in zip(tf_module.f, tf_signatures)
|
|
394
|
-
]
|
|
395
|
-
|
|
396
|
-
# We need to temporarily save since TFLite's from_concrete_functions does not
|
|
397
|
-
# allow providing names for each of the concrete functions.
|
|
398
|
-
with tempfile.TemporaryDirectory() as temp_dir_path:
|
|
399
|
-
tf.saved_model.save(
|
|
400
|
-
tf_module,
|
|
401
|
-
temp_dir_path,
|
|
402
|
-
signatures={
|
|
403
|
-
sig.name: tf_concrete_funcs[idx]
|
|
404
|
-
for idx, sig in enumerate(signatures)
|
|
405
|
-
},
|
|
406
|
-
)
|
|
407
|
-
# Clean up intermediate memory early.
|
|
408
|
-
del tf_module
|
|
409
|
-
del tf_concrete_funcs
|
|
410
|
-
gc.collect()
|
|
411
|
-
|
|
412
|
-
converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
|
|
413
|
-
converter._set_original_model_type(conversion_metadata_fb.ModelType.PYTORCH)
|
|
414
|
-
converter._experimental_enable_composite_direct_lowering = True
|
|
415
|
-
|
|
416
|
-
_set_tfl_converter_quant_flags(converter, quant_config)
|
|
417
|
-
if (
|
|
418
|
-
quant_config is not None
|
|
419
|
-
and quant_config._quantizer_mode
|
|
420
|
-
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
|
421
|
-
):
|
|
422
|
-
translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
|
|
423
|
-
quant_config.generative_recipe
|
|
424
|
-
)
|
|
425
|
-
|
|
426
|
-
_apply_tfl_backdoor_flags(converter, _tfl_converter_flags)
|
|
427
|
-
|
|
428
|
-
tflite_model = converter.convert()
|
|
429
|
-
|
|
430
|
-
if (
|
|
431
|
-
quant_config is not None
|
|
432
|
-
and quant_config._quantizer_mode
|
|
433
|
-
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
|
|
434
|
-
):
|
|
435
|
-
tflite_model = translate_recipe.quantize_model(
|
|
436
|
-
tflite_model, translated_recipe
|
|
437
|
-
)
|
|
438
|
-
|
|
439
|
-
return tflite_model
|
|
@@ -1,133 +0,0 @@
|
|
|
1
|
-
ai_edge_torch/__init__.py,sha256=WTuorXzCALfr89FC4kX_PBtKOQLipN1hcW2tMDSQW9w,1100
|
|
2
|
-
ai_edge_torch/model.py,sha256=pSyY9O7J1i-SJu7g4mFD853MJBNFE6LSzBgJw7dtWuI,4494
|
|
3
|
-
ai_edge_torch/version.py,sha256=h26UNeBme8QimcI1g_yZcMkf5IeuglS6yqwXpkaRSfM,706
|
|
4
|
-
ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
5
|
-
ai_edge_torch/convert/conversion.py,sha256=bkOyaTTZR9lT1VJMxwCSjcplheYv1HNSwt8A9kEo388,4183
|
|
6
|
-
ai_edge_torch/convert/conversion_utils.py,sha256=GAOFepARe_vxOaetplMBBaexxojSijJzXvkxft88-Lc,13945
|
|
7
|
-
ai_edge_torch/convert/converter.py,sha256=6BoHl_GEIOkTr1oBg-VzZb5tr6Rv9yDwxKczYd6cu1o,7956
|
|
8
|
-
ai_edge_torch/convert/to_channel_last_io.py,sha256=b7Q0_6Lam6IV-3TyhabVTMS7j0ppFpKDOIHTNAw2PnI,2814
|
|
9
|
-
ai_edge_torch/convert/fx_passes/__init__.py,sha256=D4Xe8YmeP2N0yEN_bc7pEJH47KkwGFf4COZOILmDL4w,2809
|
|
10
|
-
ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=WVYZuocpygHAzk9u1GNoGowAIOHTlJXyA_NklmYkRms,1672
|
|
11
|
-
ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=QaZ5JV7RazGbC2Khdai795vlO5jDc3yhgx3HHNmzHDs,8246
|
|
12
|
-
ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py,sha256=BWSU9nkD5DzxHI_WGcs9uH6qKWCw0XB2etDEV6PsZkg,4181
|
|
13
|
-
ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=eW0Yae2cL2ALYVkhsuk3wX8v41P6bkGaABtRgdPCdxk,1672
|
|
14
|
-
ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
|
|
15
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=VA9bekxPVhLk4MYlIRXnOzrSnbCtUmGj7OQ_fJcKQtc,795
|
|
16
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=KrMDtpRVgxpS6dxgT_shjYYjL8Ij3L0PNLpn-StSUU0,7546
|
|
17
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=uXCHC23pWN-3JmDtAErWbSUnL8jjlQgUAy4gqtfDsQU,1560
|
|
18
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=_FuPbJewiPTqb-aNXR-qiujvsI4J0z6p5JWp8AIg6qE,12496
|
|
19
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=o9PAcAgvS5uG0xA2io2XEWaELgwPODRRJAkfegob4so,981
|
|
20
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=sJqKFDR67svsMh9t0jFav0CzpMZCw29PV3yJ-LCjtoY,10752
|
|
21
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=bItkXVaPA9THcFypAmqldpkLuD8WpOFmKlhVbBJJkPk,2076
|
|
22
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=B-zisphkH7aRCUOJNdwHnTA0fQXuDpN08q3Qjy5bL6E,715
|
|
23
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=FkNNS7tkPm0oanUhjipJxV91-mkcL3YYBj1a8uODmfw,2296
|
|
24
|
-
ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=iAYFw6pK9sjXi_uEYRxzezIkHXQosxjNzIhGmpfRFWM,7190
|
|
25
|
-
ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
26
|
-
ai_edge_torch/convert/test/test_convert.py,sha256=k7YPpLKQ-_M89jzf0mftrga_F55B7drfreWkAr9GPWw,12789
|
|
27
|
-
ai_edge_torch/convert/test/test_convert_composites.py,sha256=tEBhunjRz6WXPidPTSwMVGfwNYCDBrXbcJ1WOUACL1U,7682
|
|
28
|
-
ai_edge_torch/convert/test/test_convert_multisig.py,sha256=XzLgxxqVEVn00JEFUeu6dXJi71pWsX0FwVwXgvZpbZs,4623
|
|
29
|
-
ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=fRR_NkvfUnsleZgNc5fS9Y4apyiRgOX-3tLNE-uSlCA,2929
|
|
30
|
-
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
|
31
|
-
ai_edge_torch/debug/culprit.py,sha256=PQaeR_csuF6F6rR9JrmltGSCkpGx1PxLyPkUiMzoj7w,14785
|
|
32
|
-
ai_edge_torch/debug/utils.py,sha256=gpK1PbiKc6KRMbtpgsBVgTNqd-RZWhqXcFJVDVlvhEI,1437
|
|
33
|
-
ai_edge_torch/debug/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
34
|
-
ai_edge_torch/debug/test/test_culprit.py,sha256=4dwskvGKHhDqzPQDFJkiifhD3505ljFEEj13h9KqBg4,3736
|
|
35
|
-
ai_edge_torch/debug/test/test_search_model.py,sha256=tWmoMJe81ssOc22Id9J2buNNC3j7QeIt7bP8WW0L57M,1603
|
|
36
|
-
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
37
|
-
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
38
|
-
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
39
|
-
ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
40
|
-
ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
41
|
-
ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=Tfy8GhWakUGBjuEG9kOLpffwcrnuWF93UzTshK_yGaM,3085
|
|
42
|
-
ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=EJQLQqx5M2v6oNzmf8M2o4dg6I3wZ4ZWngoASW4EXpM,6634
|
|
43
|
-
ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
44
|
-
ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=_0RoLi6ElYGkIVqKpDuIyGiUjhHjbyQaZjcL2iVNYh4,3055
|
|
45
|
-
ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=jYiekxKoXpGhjnsKTQJC3dTiAY1h9B7hFsOtvNiTShA,6178
|
|
46
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
47
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=sLL9ULX29IveaN5XoFqCm2DW4XBbtBF-CHaJygnKDgU,3125
|
|
48
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=PEr9olL5oINCwQK8AS1Ba4VdoavOA3eVKDxMAYiOnDk,6319
|
|
49
|
-
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
50
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=leyFwQI35Q_OCYo91j9cbKAam72A127AVVomzEqd6rs,2540
|
|
51
|
-
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=BshAPWJ96fo6YHqFiwVQWrRxVLRIJJeSk2vTRbHhzw8,6182
|
|
52
|
-
ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
53
|
-
ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=uXbmtefNnzOF7rTOQ69Gv1Xuod-PyW_ysU60T1l3RVQ,2524
|
|
54
|
-
ai_edge_torch/generative/examples/phi2/phi2.py,sha256=tYtpIaxFWh-fyDmKCdYB1I6g-UJp0dmUUObIRO_VxN0,5805
|
|
55
|
-
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
56
|
-
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
|
57
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=4L3u6R0KDDN3B4BthU2Lwvc8Tuw5M0ZR_y__Uwo7VN8,4424
|
|
58
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
|
|
59
|
-
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=NUnrzwU-77iJw0mXbWKsgmTYk6iS_GMzGf8Fb3iJ5Xc,13970
|
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=S3nRz_bJdXjxJa29eJMPLAgbehjsAdQSROTBA7AmEGg,29160
|
|
61
|
-
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
|
62
|
-
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=sYMd9OFa_VnMkn5bZ1ZA1CPhmdRHtIIcLw7j3CkOANw,8624
|
|
63
|
-
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
|
64
|
-
ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=XIXIB0vCvQKOGyIyiZeiIA5DLeSXjkudywvJS4FK7AM,2431
|
|
65
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
|
|
66
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py,sha256=wBBNM24waZ57M1rXonwesfUkKe9DqpqO3eW6BfZkrD0,2323
|
|
67
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py,sha256=c89ldwtuQ2_yspGrGa7oh7fsvTt6A86Whxa6fBK9YOQ,2526
|
|
68
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6HyOoBJrmTh54KVFf7DjNBnBS0pT4cgviYaq8HGMU,2801
|
|
69
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
|
|
70
|
-
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
71
|
-
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
|
|
72
|
-
ai_edge_torch/generative/examples/t5/t5.py,sha256=WUKIxq4cBO2SkcZSwrruIghquWij70rhfbr78M8Ivew,20861
|
|
73
|
-
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=FpiAPmeZL4c9BlxOkLoZPzVm3P8JL3zwLqPs68xDqaA,8427
|
|
74
|
-
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
75
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=RgqS5OuKiZb_EYS61i6toVRqUdNQTUzMGuiEGs6NbdU,3903
|
|
76
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=76whgq2mmHYUpNmZ1b_5fBigrrHHVbgC6kuNGvAB9zU,5795
|
|
77
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=IzK2gSkZAgBjWQwIURUfh7W19E6Ejkw9GrphgoiUkRg,4852
|
|
78
|
-
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
79
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=rnozSJHU-4UjasyIDM-Q2DvXcdckoHcy4lgb3cpSiS0,2568
|
|
80
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=3UmZonW9x9cg-HjNBrpeDnoWdSRC711cOSwN0sZ1_wA,5876
|
|
81
|
-
ai_edge_torch/generative/fx_passes/__init__.py,sha256=C5Xkh1OFSV9Xw68Q3JVQ7BYPjr1o7O6sjnmUhKeb3dg,1171
|
|
82
|
-
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=CQhQ7HGtkMHfUUBdOoa1I8fsNxnCf3Uzndvd0QQ7G5M,2005
|
|
83
|
-
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
84
|
-
ai_edge_torch/generative/layers/attention.py,sha256=ECSzuP6tlwliSAIK8Qu021L2YxqNlmoS_8er5CsyHWU,12032
|
|
85
|
-
ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
|
|
86
|
-
ai_edge_torch/generative/layers/builder.py,sha256=BKc1JbKuW0AIlPzeoTXOaPBLWTCVERTON8qYPu7RFr0,4162
|
|
87
|
-
ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
|
|
88
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=nVFfWx6HzWrPeF5FRErx5JvgUPJz-qqRvFqChTpxGc8,3099
|
|
89
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=8jDECxQUmmUMDFke67NtTy2LDTt8OiA9iMc55b-JGTU,5048
|
|
90
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
|
|
91
|
-
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=eYOmQC-nVUz6sdTou8xIIaBgQZ6aum09NA2QAI-CRnM,1389
|
|
92
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=6WMe-A5KSSujQcZ34hIeSnnor3AXrw10cQ5FKy-30IU,3390
|
|
93
|
-
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
94
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=evbrY-tBGjnlcKyZ1a44cY5XsTG9oOFXelTIxhhll1o,26911
|
|
95
|
-
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
|
96
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
|
|
97
|
-
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
98
|
-
ai_edge_torch/generative/quantize/example.py,sha256=mqi3zFUp4w198DGnRkmZCWUZdUXTkvg1_tdTdOk9IkA,1535
|
|
99
|
-
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
|
100
|
-
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=TOPmTa92pozBST6hiizhteiWkgla9oVdiF3d5ToCEoc,5152
|
|
101
|
-
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=5yCOwHTUA-SgWqP27pvCLPBj1z_AcjXCqyPwQFo15O8,2270
|
|
102
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
|
103
|
-
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
|
104
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
105
|
-
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=460YflyuWSVxcLSMpdVAaO9n_4NYjqtBLSDWBQjpD5M,5276
|
|
106
|
-
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
107
|
-
ai_edge_torch/generative/test/loader_test.py,sha256=WfH3IJvKzqum0HcrD16E0yvO6TA9ZUt2rthc82vVtsk,3342
|
|
108
|
-
ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=TJWNOS8iM5iQWvBAA33r5AeYnGvm9w_GxTCbfV93flw,4317
|
|
109
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=JoyV1CBkykKwA9o9SUq-DDMrpkwdHKNsNW_y073bKOY,7588
|
|
110
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=PttH_FH8U63U4CfMKJPfHd1_BMlTmdjt_Ko0s9FEGF0,5149
|
|
111
|
-
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
112
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=r7sh2o35lHbTxXzC1-nj-Q-iO5XJvJBpBcDXminjV6c,11771
|
|
113
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=orwszJ-K2TFb1MsmqpD31IoZWMQH79NTDj6Ieu-jXig,33979
|
|
114
|
-
ai_edge_torch/generative/utilities/t5_loader.py,sha256=WJr8bkYYn6sSO_J6Rb2vzBOh6AYlOdgLp3HTbcds7fs,16838
|
|
115
|
-
ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
|
|
116
|
-
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=EQfw6kreyvOa964JBX7CIN95jj7LgipWxvSTF6EpieY,4798
|
|
117
|
-
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=YV2YKBkh7y7j7sd7EA81vf_1hUKUvTRiy1pfqZustXc,1539
|
|
118
|
-
ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=SwlCyFKMD2VSOwabNkHaJ1ZWHHyo9bRH-rdgTHBA_oY,9817
|
|
119
|
-
ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
120
|
-
ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=RT3AcDcNCdH9IW7j3UadrZmDcv21A3zZX7O5Zxo8TA4,4275
|
|
121
|
-
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=lCQQmrJl_EG7g2eRHAeol1G2DdcWO9_s24sSz9LrODY,8254
|
|
122
|
-
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
|
123
|
-
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=7Yun-SdfJB4QKmKLR1Py5QFCMDc2mj4Ymy9bxVpE8eI,15703
|
|
124
|
-
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=4uCQAy_9HPgv4xSQa9_EQY6xPGjPQsUklZYsKv3SbcM,36182
|
|
125
|
-
ai_edge_torch/quantize/quant_config.py,sha256=yP93mRbsB03K1_dYCRIKgxRNEP4EJOYF68Rfb4w8CDg,3184
|
|
126
|
-
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
127
|
-
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
128
|
-
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=049yZFfnlVefQJAXkcn84ETzVneaZIlz8e0X1BW3vvI,4520
|
|
129
|
-
ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
130
|
-
ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/METADATA,sha256=AzY7SHlmHXXuEnqfieHBKvaqVX45QtymK82O_1MEs6A,1885
|
|
131
|
-
ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
|
132
|
-
ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
133
|
-
ai_edge_torch_nightly-0.2.0.dev20240806.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|