ai-edge-torch-nightly 0.3.0.dev20250123__py3-none-any.whl → 0.3.0.dev20250125__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_config.py +9 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
- ai_edge_torch/generative/examples/deepseek/__init__.py +14 -0
- ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/deepseek/deepseek.py +92 -0
- ai_edge_torch/generative/examples/deepseek/verify.py +70 -0
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +3 -0
- ai_edge_torch/generative/layers/experimental/__init__.py +14 -0
- ai_edge_torch/generative/layers/experimental/attention.py +269 -0
- ai_edge_torch/generative/layers/experimental/kv_cache.py +314 -0
- ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py +97 -0
- ai_edge_torch/generative/layers/experimental/types.py +97 -0
- ai_edge_torch/generative/layers/kv_cache.py +2 -1
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +11 -2
- ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
- ai_edge_torch/generative/utilities/converter.py +18 -2
- ai_edge_torch/generative/utilities/model_builder.py +6 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/RECORD +28 -18
- {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250123.dist-info → ai_edge_torch_nightly-0.3.0.dev20250125.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,97 @@
|
|
1
|
+
# Copyright 2025 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# A listing of types that describe the K and V tensors in KV caches.
|
16
|
+
|
17
|
+
import enum
|
18
|
+
from enum import Enum, auto
|
19
|
+
from typing import Tuple
|
20
|
+
from torch import nn
|
21
|
+
|
22
|
+
|
23
|
+
@enum.unique
|
24
|
+
class TensorDims(Enum):
|
25
|
+
BATCH = enum.auto()
|
26
|
+
SEQUENCE = enum.auto()
|
27
|
+
NUM_HEADS = enum.auto()
|
28
|
+
HEAD_DIM = enum.auto()
|
29
|
+
MODEL_DIM = enum.auto() # often num_heads * head_dim
|
30
|
+
|
31
|
+
|
32
|
+
DIM_TO_LETTER = {
|
33
|
+
TensorDims.BATCH: 'B',
|
34
|
+
TensorDims.SEQUENCE: 'T',
|
35
|
+
TensorDims.NUM_HEADS: 'N',
|
36
|
+
TensorDims.HEAD_DIM: 'H',
|
37
|
+
TensorDims.MODEL_DIM: 'D',
|
38
|
+
}
|
39
|
+
|
40
|
+
|
41
|
+
class TensorDimensionMeta(type):
|
42
|
+
"""Metaclass to create classes representing an order of tensor dimensions."""
|
43
|
+
|
44
|
+
def __new__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
|
45
|
+
"""Creates a new class with the given name and tensor dimension order.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
name: Name of the new class.
|
49
|
+
bases: Base classes for the new class.
|
50
|
+
attrs: Attributes for the new class.
|
51
|
+
dimensions: A tuple of TensorDims defining the order.
|
52
|
+
"""
|
53
|
+
|
54
|
+
attrs['dimensions'] = (
|
55
|
+
dimensions # Store the dimensions as a class attribute
|
56
|
+
)
|
57
|
+
return super().__new__(cls, name, bases, attrs)
|
58
|
+
|
59
|
+
def __init__(cls, name, bases, attrs, dimensions: Tuple[TensorDims]):
|
60
|
+
super().__init__(name, bases, attrs)
|
61
|
+
|
62
|
+
def __repr__(cls):
|
63
|
+
return f'{cls.__name__}'
|
64
|
+
|
65
|
+
|
66
|
+
def create_tensor_dimension_order_class(dims: Tuple[TensorDims]):
|
67
|
+
"""Creates a TensorDimensionMeta class with the specified dimensions.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
dims: A tuple of TensorDims.
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
A new class representing the tensor dimension order.
|
74
|
+
"""
|
75
|
+
name = ''.join(DIM_TO_LETTER[d] for d in dims)
|
76
|
+
# Derive from nn.Module for torch tracing compatibility.
|
77
|
+
return TensorDimensionMeta(name, (nn.Module,), {}, dimensions=dims)
|
78
|
+
|
79
|
+
|
80
|
+
BTNH = create_tensor_dimension_order_class((
|
81
|
+
TensorDims.BATCH,
|
82
|
+
TensorDims.SEQUENCE,
|
83
|
+
TensorDims.NUM_HEADS,
|
84
|
+
TensorDims.HEAD_DIM,
|
85
|
+
))
|
86
|
+
BNTH = create_tensor_dimension_order_class((
|
87
|
+
TensorDims.BATCH,
|
88
|
+
TensorDims.NUM_HEADS,
|
89
|
+
TensorDims.SEQUENCE,
|
90
|
+
TensorDims.HEAD_DIM,
|
91
|
+
))
|
92
|
+
BNHT = create_tensor_dimension_order_class((
|
93
|
+
TensorDims.BATCH,
|
94
|
+
TensorDims.NUM_HEADS,
|
95
|
+
TensorDims.HEAD_DIM,
|
96
|
+
TensorDims.SEQUENCE,
|
97
|
+
))
|
@@ -81,7 +81,8 @@ class KVCache:
|
|
81
81
|
"""
|
82
82
|
caches = [
|
83
83
|
KVCacheEntry.from_model_config(
|
84
|
-
config.kv_cache_max
|
84
|
+
config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
|
85
|
+
else config.block_config(idx).kv_cache_max_len,
|
85
86
|
config.block_config(idx).attn_config,
|
86
87
|
dtype,
|
87
88
|
device,
|
@@ -164,6 +164,9 @@ class TransformerBlockConfig:
|
|
164
164
|
parallel_residual: bool = False
|
165
165
|
# The Attention computation will include relative positional bias.
|
166
166
|
relative_attention: bool = False
|
167
|
+
# KV Cache length for this block. Only used when attention types are different
|
168
|
+
# across blocks.
|
169
|
+
kv_cache_max_len: Optional[int] = None
|
167
170
|
|
168
171
|
|
169
172
|
@dataclasses.dataclass
|
@@ -200,7 +203,8 @@ class ModelConfig:
|
|
200
203
|
embedding_use_bias: bool = False
|
201
204
|
# Image embedding parameters.
|
202
205
|
image_embedding: Optional[ImageEmbeddingConfig] = None
|
203
|
-
|
206
|
+
# Number of image tokens
|
207
|
+
num_mm_tokens_per_image: Optional[int] = None
|
204
208
|
# Use bias term within LLM's HEAD.
|
205
209
|
lm_head_use_bias: bool = False
|
206
210
|
# Whether LLM's HEAD shares the weight of the embedding.
|
@@ -17,6 +17,7 @@
|
|
17
17
|
|
18
18
|
import ai_edge_torch
|
19
19
|
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
|
20
|
+
from ai_edge_torch.generative.examples.deepseek import deepseek
|
20
21
|
from ai_edge_torch.generative.examples.gemma import gemma1
|
21
22
|
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
23
|
from ai_edge_torch.generative.examples.llama import llama
|
@@ -150,16 +151,15 @@ class TestModelConversion(googletest.TestCase):
|
|
150
151
|
ai_edge_torch.config.in_oss,
|
151
152
|
reason="tests with custom ops are not supported in oss",
|
152
153
|
)
|
153
|
-
|
154
154
|
def test_smollm2(self):
|
155
155
|
config = smollm.get_fake_model_config_v2()
|
156
156
|
pytorch_model = smollm.SmolLM2(config).eval()
|
157
157
|
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
158
|
+
|
158
159
|
@googletest.skipIf(
|
159
160
|
ai_edge_torch.config.in_oss,
|
160
161
|
reason="tests with custom ops are not supported in oss",
|
161
162
|
)
|
162
|
-
|
163
163
|
def test_openelm(self):
|
164
164
|
config = openelm.get_fake_model_config()
|
165
165
|
pytorch_model = openelm.OpenELM(config).eval()
|
@@ -174,6 +174,15 @@ class TestModelConversion(googletest.TestCase):
|
|
174
174
|
pytorch_model = qwen.Qwen(config).eval()
|
175
175
|
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
|
176
176
|
|
177
|
+
@googletest.skipIf(
|
178
|
+
ai_edge_torch.config.in_oss,
|
179
|
+
reason="tests with custom ops are not supported in oss",
|
180
|
+
)
|
181
|
+
def test_deepseek(self):
|
182
|
+
config = deepseek.get_fake_model_config()
|
183
|
+
pytorch_model = deepseek.DeepSeekDistillQwen(config).eval()
|
184
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
|
185
|
+
|
177
186
|
@googletest.skipIf(
|
178
187
|
ai_edge_torch.config.in_oss,
|
179
188
|
reason="tests with custom ops are not supported in oss",
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# A custom 4D batched matrix multiplication op (bmm_4d) and its StableHLO lowering.
|
16
|
+
from dataclasses import dataclass
|
17
|
+
import glob
|
18
|
+
import os
|
19
|
+
from typing import Sequence
|
20
|
+
from ai_edge_torch.odml_torch import lowerings
|
21
|
+
from jax._src.lib.mlir import ir
|
22
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
# Use torch.library.custom_op to define a new custom operator.
|
27
|
+
@torch.library.custom_op("ai_edge_torch::bmm_4d", mutates_args=())
|
28
|
+
def bmm_4d(
|
29
|
+
lhs: torch.Tensor,
|
30
|
+
rhs: torch.Tensor,
|
31
|
+
) -> torch.Tensor:
|
32
|
+
if not (lhs.ndim == 4 and rhs.ndim == 4):
|
33
|
+
raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
|
34
|
+
d0_can_bcast = lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
|
35
|
+
d1_can_bcast = lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
|
36
|
+
if not (d0_can_bcast and d1_can_bcast):
|
37
|
+
raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")
|
38
|
+
|
39
|
+
if not lhs.shape[-1] == rhs.shape[-1]:
|
40
|
+
raise ValueError("bmm_4d requires LHS and RHS have same last dimension.")
|
41
|
+
|
42
|
+
return torch.einsum("abcd,abed->abce", lhs, rhs)
|
43
|
+
|
44
|
+
|
45
|
+
# Use register_fake to add a ``FakeTensor`` kernel for the operator
|
46
|
+
@bmm_4d.register_fake
|
47
|
+
def _(lhs, rhs):
|
48
|
+
return torch.einsum("abcd,abed->abce", lhs, rhs)
|
49
|
+
|
50
|
+
|
51
|
+
@lowerings.lower(torch.ops.ai_edge_torch.bmm_4d)
|
52
|
+
def _bmm_4d_lower(
|
53
|
+
lctx,
|
54
|
+
lhs: ir.Value,
|
55
|
+
rhs: ir.Value,
|
56
|
+
):
|
57
|
+
dot_dnums = stablehlo.DotDimensionNumbers.get(
|
58
|
+
lhs_batching_dimensions=[0, 1],
|
59
|
+
rhs_batching_dimensions=[0, 1],
|
60
|
+
lhs_contracting_dimensions=(3,),
|
61
|
+
rhs_contracting_dimensions=(3,),
|
62
|
+
)
|
63
|
+
return stablehlo.dot_general(
|
64
|
+
ir.RankedTensorType.get(
|
65
|
+
(
|
66
|
+
lhs.type.shape[0],
|
67
|
+
lhs.type.shape[1],
|
68
|
+
lhs.type.shape[2],
|
69
|
+
rhs.type.shape[2],
|
70
|
+
),
|
71
|
+
lhs.type.element_type,
|
72
|
+
),
|
73
|
+
lhs,
|
74
|
+
rhs,
|
75
|
+
dot_dnums,
|
76
|
+
)
|
@@ -19,7 +19,6 @@ import os
|
|
19
19
|
from typing import Optional, Union
|
20
20
|
from ai_edge_torch._convert import converter as converter_utils
|
21
21
|
from ai_edge_torch.generative.layers import lora as lora_utils
|
22
|
-
import ai_edge_torch.generative.layers.kv_cache as kv_utils
|
23
22
|
import ai_edge_torch.generative.layers.model_config as cfg
|
24
23
|
from ai_edge_torch.generative.quantize import quant_recipes
|
25
24
|
from ai_edge_torch.generative.utilities.model_builder import ExportConfig
|
@@ -151,9 +150,21 @@ def _export_helper(
|
|
151
150
|
else None
|
152
151
|
)
|
153
152
|
|
153
|
+
if export_config.prefill_mask is None:
|
154
|
+
prefill_masks = None
|
155
|
+
elif isinstance(export_config.prefill_mask, torch.Tensor):
|
156
|
+
prefill_masks = [export_config.prefill_mask]
|
157
|
+
elif isinstance(export_config.prefill_mask, list):
|
158
|
+
prefill_masks = export_config.prefill_mask
|
159
|
+
else:
|
160
|
+
raise ValueError('Prefill masks unrecognized.')
|
161
|
+
|
162
|
+
if prefill_masks:
|
163
|
+
assert len(prefill_masks) == len(prefill_seq_lens)
|
164
|
+
|
154
165
|
decode_token = torch.tensor([[0]], dtype=torch.int)
|
155
166
|
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
156
|
-
kv =
|
167
|
+
kv = export_config.kvcache_cls.from_model_config(config)
|
157
168
|
|
158
169
|
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
159
170
|
|
@@ -174,6 +185,9 @@ def _export_helper(
|
|
174
185
|
'input_pos': prefill_input_pos,
|
175
186
|
'kv_cache': kv,
|
176
187
|
}
|
188
|
+
if prefill_masks is not None:
|
189
|
+
sample_kwargs['mask'] = prefill_masks[i]
|
190
|
+
|
177
191
|
if lora is not None:
|
178
192
|
prefill_signature_name += f'_lora_r{lora.get_rank()}'
|
179
193
|
sample_kwargs['lora'] = lora
|
@@ -199,6 +213,8 @@ def _export_helper(
|
|
199
213
|
'input_pos': decode_input_pos,
|
200
214
|
'kv_cache': kv,
|
201
215
|
}
|
216
|
+
if export_config.decode_mask is not None:
|
217
|
+
sample_kwargs['mask'] = export_config.decode_mask
|
202
218
|
if lora is not None:
|
203
219
|
sample_kwargs['lora'] = lora
|
204
220
|
|
@@ -17,7 +17,7 @@
|
|
17
17
|
|
18
18
|
import copy
|
19
19
|
from dataclasses import dataclass
|
20
|
-
from typing import Optional, Tuple
|
20
|
+
from typing import List, Optional, Tuple
|
21
21
|
|
22
22
|
from ai_edge_torch.generative.layers import attention
|
23
23
|
from ai_edge_torch.generative.layers import builder
|
@@ -55,6 +55,11 @@ class ExportConfig:
|
|
55
55
|
# On prefill signatures, should the model produce logit output?
|
56
56
|
# When False, only decode signatures will produce output.
|
57
57
|
output_logits_on_prefill: bool = False
|
58
|
+
# Attention masks given as inputs to the model.
|
59
|
+
prefill_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
60
|
+
decode_mask: Optional[torch.Tensor | List[torch.Tensor]] = None
|
61
|
+
# The KV Cache class for K and V buffers in attention.
|
62
|
+
kvcache_cls: type = kv_utils.KVCache
|
58
63
|
|
59
64
|
|
60
65
|
class DecoderOnlyModel(nn.Module):
|
@@ -21,8 +21,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional
|
|
21
21
|
import torch
|
22
22
|
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
23
23
|
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
|
24
|
-
from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
|
25
|
-
from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
|
26
24
|
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
|
27
25
|
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
28
26
|
from torch.ao.quantization.quantizer import QuantizationSpec
|
@@ -47,6 +45,28 @@ __all__ = [
|
|
47
45
|
"propagate_annotation",
|
48
46
|
]
|
49
47
|
|
48
|
+
# Example inputs for conv-bn1d patterns
|
49
|
+
_conv1d_bn_example_inputs = (
|
50
|
+
torch.randn(1, 1, 3), # x
|
51
|
+
torch.randn(1, 1, 1), # conv_weight
|
52
|
+
torch.randn(1), # conv_bias
|
53
|
+
torch.randn(1), # bn_weight
|
54
|
+
torch.randn(1), # bn_bias
|
55
|
+
torch.randn(1), # bn_running_mean
|
56
|
+
torch.randn(1), # bn_running_var
|
57
|
+
)
|
58
|
+
|
59
|
+
# Example inputs for conv-bn2d patterns
|
60
|
+
_conv2d_bn_example_inputs = (
|
61
|
+
torch.randn(1, 1, 3, 3), # x
|
62
|
+
torch.randn(1, 1, 1, 1), # conv_weight
|
63
|
+
torch.randn(1), # conv_bias
|
64
|
+
torch.randn(1), # bn_weight
|
65
|
+
torch.randn(1), # bn_bias
|
66
|
+
torch.randn(1), # bn_running_mean
|
67
|
+
torch.randn(1), # bn_running_var
|
68
|
+
)
|
69
|
+
|
50
70
|
|
51
71
|
@dataclass(eq=True, frozen=True)
|
52
72
|
class QuantizationConfig:
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250123
|
3
|
+
Version: 0.3.0.dev20250125
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,8 +1,8 @@
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
|
2
|
-
ai_edge_torch/_config.py,sha256=
|
2
|
+
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=yuz53SwRvngiQ41D-VX7MPmVGe-Vi-UR3v12E-o3P4I,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -16,11 +16,11 @@ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZv
|
|
16
16
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
17
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
19
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
19
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
|
22
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
23
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
23
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
25
25
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
@@ -49,6 +49,10 @@ ai_edge_torch/generative/examples/amd_llama_135m/__init__.py,sha256=hHLluseD2R0H
|
|
49
49
|
ai_edge_torch/generative/examples/amd_llama_135m/amd_llama_135m.py,sha256=urNif89PyCXbdXT5spOeDvdM5luJ-a5HaXHM86v4JnU,2766
|
50
50
|
ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=Oqlg5ZoUuG2aU3067QaPpmEXWOdB8GEq7u_NWoBpoB4,2337
|
51
51
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
|
52
|
+
ai_edge_torch/generative/examples/deepseek/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
53
|
+
ai_edge_torch/generative/examples/deepseek/convert_to_tflite.py,sha256=I5eA-XfFdHjYwDsLIjn23T2e-IgnSCQ129-5DOU8j44,2532
|
54
|
+
ai_edge_torch/generative/examples/deepseek/deepseek.py,sha256=AOAJ7ltXwY5IbmcCP2nVHW9FmRwexzfNxnoDlR-sW9c,2885
|
55
|
+
ai_edge_torch/generative/examples/deepseek/verify.py,sha256=sDYBhmE_CeZw5iLIQ7rJNGLjhcTyKUQGdg7_QQBh9WM,2398
|
52
56
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
57
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=8HJi0cutxPstafVNs2LfBKdUzufVucje1Vrfjw_RS_g,2527
|
54
58
|
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=MX8fZhJJPZ5IoMiNHX0tLkRpHYqVuh4qhW0rkeIfmYw,2529
|
@@ -97,7 +101,7 @@ ai_edge_torch/generative/examples/smollm/verify.py,sha256=KpYxVz_lv61YWy6HLfwT68
|
|
97
101
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
98
102
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
99
103
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=5M4auM33SgCTODt0VT8TO-EVILruqGDRiNILBPeB83Y,6072
|
100
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
|
104
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=GtwKAByEk0ENGEWbUmC2mAAPkbLZ3M5xH1HIToyu8QE,5307
|
101
105
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=sQKQ-k6H9kG2brgwLsktjCMeN2h0POyfMP6iNsPNKWc,16271
|
102
106
|
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=6W58LxmHHkz2ctgpknQkyoDANZAnE9Byp_svfqLpQf0,34793
|
103
107
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
@@ -128,12 +132,17 @@ ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8k
|
|
128
132
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
129
133
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
130
134
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
131
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
135
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
|
132
136
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
133
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
137
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQteKp-UM1YvMbbWYlkNw,7999
|
134
138
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
135
139
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
136
140
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
141
|
+
ai_edge_torch/generative/layers/experimental/__init__.py,sha256=nz-K0h8DfiATHzR6s1_bCw2akUmHWffU1bDRSkIzSqI,592
|
142
|
+
ai_edge_torch/generative/layers/experimental/attention.py,sha256=KC1UkIhaPx2DNRfkxCXO7eZZMeNm2UxkjFi-fB8HVhw,9212
|
143
|
+
ai_edge_torch/generative/layers/experimental/kv_cache.py,sha256=gE_q8YoSzOhGgbSm0K91jXkbFKnFJpuYf-hxMzLNw78,8976
|
144
|
+
ai_edge_torch/generative/layers/experimental/scaled_dot_product_attention.py,sha256=1vMh1L3uYX4ptKQMWcAjxkL1v2-g0jmOiuai8ydp0dc,2879
|
145
|
+
ai_edge_torch/generative/layers/experimental/types.py,sha256=bPPxw6TOCZVWdeDP3vCbOnjNP5-bdUMmfsfO-EtdazQ,2847
|
137
146
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
138
147
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
|
139
148
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
@@ -151,14 +160,15 @@ ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOryp
|
|
151
160
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
152
161
|
ai_edge_torch/generative/test/test_lora.py,sha256=6QIM6RLTc2HrodGpp_aS3OxM9Rco2KAzEnYgotkg41M,5310
|
153
162
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
|
154
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
163
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=AJs_ARfWUqwuFRwYtQQOLd87CiD4mUDwAhq885cqc4Q,12875
|
155
164
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
156
165
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
157
166
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
158
|
-
ai_edge_torch/generative/utilities/
|
167
|
+
ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
|
168
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=6siSpCvH_cLV-eP40lkF_AqjBpYv68xeMRQ722fKgE0,8065
|
159
169
|
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
160
170
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
161
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
171
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=5WqcxpeTdt51nVoUwt9g5kKB5wQKj2eYbiaz7k6Ofxg,6815
|
162
172
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
163
173
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
164
174
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
@@ -197,7 +207,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw5
|
|
197
207
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
198
208
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
199
209
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
200
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
210
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=fte81SZxgxeMcI3wWVKSTnUjIxVVilOJ6H3TybXyDmQ,11558
|
201
211
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
202
212
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
203
213
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
@@ -207,13 +217,13 @@ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZm
|
|
207
217
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
208
218
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
209
219
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
210
|
-
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=
|
220
|
+
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
|
211
221
|
ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
|
212
222
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
213
223
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
214
224
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
215
|
-
ai_edge_torch_nightly-0.3.0.
|
216
|
-
ai_edge_torch_nightly-0.3.0.
|
217
|
-
ai_edge_torch_nightly-0.3.0.
|
218
|
-
ai_edge_torch_nightly-0.3.0.
|
219
|
-
ai_edge_torch_nightly-0.3.0.
|
225
|
+
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
226
|
+
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/METADATA,sha256=BkUH2iAinJYGmBLTMdeYSpihXAHY_mBOkeprZLPaDGk,1966
|
227
|
+
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
228
|
+
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
229
|
+
ai_edge_torch_nightly-0.3.0.dev20250125.dist-info/RECORD,,
|
File without changes
|
File without changes
|