ai-edge-torch-nightly 0.3.0.dev20250122__py3-none-any.whl → 0.3.0.dev20250124__py3-none-any.whl
This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_config.py +9 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +11 -8
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +22 -24
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +3 -4
- ai_edge_torch/generative/layers/kv_cache.py +2 -1
- ai_edge_torch/generative/layers/model_config.py +5 -1
- ai_edge_torch/generative/utilities/bmm_4d.py +76 -0
- ai_edge_torch/generative/utilities/converter.py +5 -0
- ai_edge_torch/generative/utilities/model_builder.py +3 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +1 -1
- ai_edge_torch/quantize/pt2e_quantizer_utils.py +22 -2
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250122.dist-info → ai_edge_torch_nightly-0.3.0.dev20250124.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20250122.dist-info → ai_edge_torch_nightly-0.3.0.dev20250124.dist-info}/RECORD +17 -16
- {ai_edge_torch_nightly-0.3.0.dev20250122.dist-info → ai_edge_torch_nightly-0.3.0.dev20250124.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250122.dist-info → ai_edge_torch_nightly-0.3.0.dev20250124.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20250122.dist-info → ai_edge_torch_nightly-0.3.0.dev20250124.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py
CHANGED
@@ -65,5 +65,14 @@ class _Config:
|
|
65
65
|
def enable_group_norm_composite(self, value: bool):
|
66
66
|
os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
|
67
67
|
|
68
|
+
@property
|
69
|
+
def layout_optimize_partitioner(self) -> str:
|
70
|
+
"""The algorithm to use for layout optimization."""
|
71
|
+
return os.environ.get("AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER", "DEFAULT")
|
72
|
+
|
73
|
+
@layout_optimize_partitioner.setter
|
74
|
+
def layout_optimize_partitioner(self, value: str):
|
75
|
+
os.environ["AIEDGETORCH_LAYOUT_OPTIMIZE_PARTITIONER"] = str(value).upper()
|
76
|
+
|
68
77
|
|
69
78
|
config = _Config()
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
CHANGED
@@ -201,8 +201,14 @@ def _aten_group_norm_checker(node):
|
|
201
201
|
return NHWCable(can_be=can_be, must_be=must_be)
|
202
202
|
|
203
203
|
|
204
|
-
@nhwcable_node_checkers.register(aten.native_group_norm)
|
204
|
+
@nhwcable_node_checkers.register(aten.native_group_norm.default)
|
205
205
|
def _aten_native_group_norm_checker(node):
|
206
|
+
# aten.group_norm is removed from the decomp table, so aten.native_group_norm
|
207
|
+
# should never exist in the graph. However, torch 2.5.1 could ignore the
|
208
|
+
# decomp table updates, so still add this native_group_norm checker and
|
209
|
+
# rewriter to be safe.
|
210
|
+
# The checker and rewriter are the same as the ones for aten.group_norm.
|
211
|
+
|
206
212
|
val = node.meta.get("val")
|
207
213
|
if (
|
208
214
|
not isinstance(val, (list, tuple))
|
@@ -210,13 +216,10 @@ def _aten_native_group_norm_checker(node):
|
|
210
216
|
or not hasattr(val[0], "shape")
|
211
217
|
):
|
212
218
|
return NHWCable(can_be=False, must_be=False)
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
# TODO(b/354780253): Re-enable NHWC rewriter with proper lowering.
|
218
|
-
return NHWCable(can_be=False, must_be=False)
|
219
|
-
return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
|
219
|
+
|
220
|
+
can_be = len(val[0].shape) == 4
|
221
|
+
must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
|
222
|
+
return NHWCable(can_be=can_be, must_be=must_be)
|
220
223
|
|
221
224
|
|
222
225
|
# ==== Ops must be NCHW
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
CHANGED
@@ -391,34 +391,32 @@ def _aten_native_group_norm(node):
|
|
391
391
|
eps: float,
|
392
392
|
**kwargs,
|
393
393
|
):
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
flattened_inner_size,
|
399
|
-
num_groups,
|
400
|
-
num_channels // num_groups,
|
401
|
-
],
|
402
|
-
)
|
403
|
-
reduction_dims = [1, 3]
|
404
|
-
|
405
|
-
biased_var, mean = torch.var_mean(
|
406
|
-
input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
|
394
|
+
is_composite_supported = (
|
395
|
+
ai_edge_torch.config.enable_group_norm_composite
|
396
|
+
and weight is not None
|
397
|
+
and bias is not None
|
407
398
|
)
|
408
|
-
rstd = torch.rsqrt(biased_var + eps)
|
409
|
-
|
410
|
-
out = (input_reshaped - mean) * rstd
|
411
|
-
out = torch.reshape(out, input.shape)
|
412
399
|
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
400
|
+
builder = None
|
401
|
+
if is_composite_supported:
|
402
|
+
builder = StableHLOCompositeBuilder(
|
403
|
+
name="odml.group_norm",
|
404
|
+
attr={
|
405
|
+
"num_groups": num_groups,
|
406
|
+
"epsilon": eps,
|
407
|
+
"reduction_axes": [3],
|
408
|
+
"channel_axis": 3,
|
409
|
+
},
|
410
|
+
)
|
411
|
+
input, weight, bias = builder.mark_inputs(input, weight, bias)
|
417
412
|
|
418
|
-
|
419
|
-
|
413
|
+
input = utils.tensor_to_nchw(input)
|
414
|
+
output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
|
415
|
+
output = utils.tensor_to_nhwc(output)
|
420
416
|
|
421
|
-
|
417
|
+
if builder is not None:
|
418
|
+
output = builder.mark_outputs(output)
|
419
|
+
return (output, None, None)
|
422
420
|
|
423
421
|
node.target = native_group_norm
|
424
422
|
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py
CHANGED
@@ -18,6 +18,7 @@ import operator
|
|
18
18
|
import os
|
19
19
|
from typing import Union
|
20
20
|
|
21
|
+
import ai_edge_torch
|
21
22
|
from ai_edge_torch import fx_infra
|
22
23
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_check # NOQA
|
23
24
|
from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark # NOQA
|
@@ -261,10 +262,8 @@ class OptimizeLayoutTransposesPass(fx_infra.ExportedProgramPassBase):
|
|
261
262
|
self.mark_const_nodes(exported_program)
|
262
263
|
|
263
264
|
graph_module = exported_program.graph_module
|
264
|
-
partitioner =
|
265
|
-
|
266
|
-
)
|
267
|
-
if partitioner == "MINCUT":
|
265
|
+
partitioner = ai_edge_torch.config.layout_optimize_partitioner
|
266
|
+
if partitioner in ("MINCUT", "OPTIMAL"):
|
268
267
|
graph_module = layout_partitioners.min_cut.partition(graph_module)
|
269
268
|
elif partitioner == "GREEDY":
|
270
269
|
graph_module = layout_partitioners.greedy.partition(graph_module)
|
ai_edge_torch/generative/layers/kv_cache.py
CHANGED
@@ -81,7 +81,8 @@ class KVCache:
|
|
81
81
|
"""
|
82
82
|
caches = [
|
83
83
|
KVCacheEntry.from_model_config(
|
84
|
-
config.kv_cache_max
|
84
|
+
config.kv_cache_max if not config.block_config(idx).kv_cache_max_len
|
85
|
+
else config.block_config(idx).kv_cache_max_len,
|
85
86
|
config.block_config(idx).attn_config,
|
86
87
|
dtype,
|
87
88
|
device,
|
ai_edge_torch/generative/layers/model_config.py
CHANGED
@@ -164,6 +164,9 @@ class TransformerBlockConfig:
|
|
164
164
|
parallel_residual: bool = False
|
165
165
|
# The Attention computation will include relative positional bias.
|
166
166
|
relative_attention: bool = False
|
167
|
+
# KV Cache length for this block. Only used when attention types are different
|
168
|
+
# across blocks
|
169
|
+
kv_cache_max_len: Optional[int] = None
|
167
170
|
|
168
171
|
|
169
172
|
@dataclasses.dataclass
|
@@ -200,7 +203,8 @@ class ModelConfig:
|
|
200
203
|
embedding_use_bias: bool = False
|
201
204
|
# Image embedding parameters.
|
202
205
|
image_embedding: Optional[ImageEmbeddingConfig] = None
|
203
|
-
|
206
|
+
# Number of image tokens
|
207
|
+
num_mm_tokens_per_image: Optional[int] = None
|
204
208
|
# Use bias term within LLM's HEAD.
|
205
209
|
lm_head_use_bias: bool = False
|
206
210
|
# Whether LLM's HEAD shares the weight of the embedding.
|
ai_edge_torch/generative/utilities/bmm_4d.py
ADDED
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
# Common utility functions for data loading etc.
|
16
|
+
from dataclasses import dataclass
|
17
|
+
import glob
|
18
|
+
import os
|
19
|
+
from typing import Sequence
|
20
|
+
from ai_edge_torch.odml_torch import lowerings
|
21
|
+
from jax._src.lib.mlir import ir
|
22
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
# Use torch.library.custom_op to define a new custom operator.
|
27
|
+
@torch.library.custom_op("ai_edge_torch::bmm_4d", mutates_args=())
|
28
|
+
def bmm_4d(
|
29
|
+
lhs: torch.Tensor,
|
30
|
+
rhs: torch.Tensor,
|
31
|
+
) -> torch.Tensor:
|
32
|
+
if not (lhs.ndim == 4 and rhs.ndim == 4):
|
33
|
+
raise ValueError("bmm_4d requires LHS and RHS have rank 4.")
|
34
|
+
d0_can_bcast = lhs.shape[0] == rhs.shape[0] or lhs.shape[0] == 1 or rhs.shape[0] == 1
|
35
|
+
d1_can_bcast = lhs.shape[1] == rhs.shape[1] or lhs.shape[1] == 1 or rhs.shape[1] == 1
|
36
|
+
if not (d0_can_bcast and d1_can_bcast):
|
37
|
+
raise ValueError("bmm_4d requires that dimensions 0 and 1 can broadcast.")
|
38
|
+
|
39
|
+
if not lhs.shape[-1] == rhs.shape[-1]:
|
40
|
+
raise ValueError("bmm_4d requires LHS and RHS have same last dimension.")
|
41
|
+
|
42
|
+
return torch.einsum("abcd,abed->abce", lhs, rhs)
|
43
|
+
|
44
|
+
|
45
|
+
# Use register_fake to add a ``FakeTensor`` kernel for the operator
|
46
|
+
@bmm_4d.register_fake
|
47
|
+
def _(lhs, rhs):
|
48
|
+
return torch.einsum("abcd,abed->abce", lhs, rhs)
|
49
|
+
|
50
|
+
|
51
|
+
@lowerings.lower(torch.ops.ai_edge_torch.bmm_4d)
|
52
|
+
def _bmm_4d_lower(
|
53
|
+
lctx,
|
54
|
+
lhs: ir.Value,
|
55
|
+
rhs: ir.Value,
|
56
|
+
):
|
57
|
+
dot_dnums = stablehlo.DotDimensionNumbers.get(
|
58
|
+
lhs_batching_dimensions=[0, 1],
|
59
|
+
rhs_batching_dimensions=[0, 1],
|
60
|
+
lhs_contracting_dimensions=(3,),
|
61
|
+
rhs_contracting_dimensions=(3,),
|
62
|
+
)
|
63
|
+
return stablehlo.dot_general(
|
64
|
+
ir.RankedTensorType.get(
|
65
|
+
(
|
66
|
+
lhs.type.shape[0],
|
67
|
+
lhs.type.shape[1],
|
68
|
+
lhs.type.shape[2],
|
69
|
+
rhs.type.shape[2],
|
70
|
+
),
|
71
|
+
lhs.type.element_type,
|
72
|
+
),
|
73
|
+
lhs,
|
74
|
+
rhs,
|
75
|
+
dot_dnums,
|
76
|
+
)
|
ai_edge_torch/generative/utilities/converter.py
CHANGED
@@ -174,6 +174,9 @@ def _export_helper(
|
|
174
174
|
'input_pos': prefill_input_pos,
|
175
175
|
'kv_cache': kv,
|
176
176
|
}
|
177
|
+
if export_config.prefill_mask is not None:
|
178
|
+
sample_kwargs['mask'] = export_config.prefill_mask
|
179
|
+
|
177
180
|
if lora is not None:
|
178
181
|
prefill_signature_name += f'_lora_r{lora.get_rank()}'
|
179
182
|
sample_kwargs['lora'] = lora
|
@@ -199,6 +202,8 @@ def _export_helper(
|
|
199
202
|
'input_pos': decode_input_pos,
|
200
203
|
'kv_cache': kv,
|
201
204
|
}
|
205
|
+
if export_config.decode_mask is not None:
|
206
|
+
sample_kwargs['mask'] = export_config.decode_mask
|
202
207
|
if lora is not None:
|
203
208
|
sample_kwargs['lora'] = lora
|
204
209
|
|
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -55,6 +55,9 @@ class ExportConfig:
|
|
55
55
|
# On prefill signatures, should the model produce logit output?
|
56
56
|
# When False, only decode signatures will produce output.
|
57
57
|
output_logits_on_prefill: bool = False
|
58
|
+
# Attention masks given as inputs to the model.
|
59
|
+
prefill_mask: Optional[torch.Tensor] = None
|
60
|
+
decode_mask: Optional[torch.Tensor] = None
|
58
61
|
|
59
62
|
|
60
63
|
class DecoderOnlyModel(nn.Module):
|
ai_edge_torch/quantize/pt2e_quantizer_utils.py
CHANGED
@@ -21,8 +21,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional
|
|
21
21
|
import torch
|
22
22
|
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
|
23
23
|
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
|
24
|
-
from torch.ao.quantization.pt2e.utils import _conv1d_bn_example_inputs
|
25
|
-
from torch.ao.quantization.pt2e.utils import _conv2d_bn_example_inputs
|
26
24
|
from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
|
27
25
|
from torch.ao.quantization.quantizer import QuantizationAnnotation
|
28
26
|
from torch.ao.quantization.quantizer import QuantizationSpec
|
@@ -47,6 +45,28 @@ __all__ = [
|
|
47
45
|
"propagate_annotation",
|
48
46
|
]
|
49
47
|
|
48
|
+
# Example inputs for conv-bn1d patterns
|
49
|
+
_conv1d_bn_example_inputs = (
|
50
|
+
torch.randn(1, 1, 3), # x
|
51
|
+
torch.randn(1, 1, 1), # conv_weight
|
52
|
+
torch.randn(1), # conv_bias
|
53
|
+
torch.randn(1), # bn_weight
|
54
|
+
torch.randn(1), # bn_bias
|
55
|
+
torch.randn(1), # bn_running_mean
|
56
|
+
torch.randn(1), # bn_running_var
|
57
|
+
)
|
58
|
+
|
59
|
+
# Example inputs for conv-bn2d patterns
|
60
|
+
_conv2d_bn_example_inputs = (
|
61
|
+
torch.randn(1, 1, 3, 3), # x
|
62
|
+
torch.randn(1, 1, 1, 1), # conv_weight
|
63
|
+
torch.randn(1), # conv_bias
|
64
|
+
torch.randn(1), # bn_weight
|
65
|
+
torch.randn(1), # bn_bias
|
66
|
+
torch.randn(1), # bn_running_mean
|
67
|
+
torch.randn(1), # bn_running_var
|
68
|
+
)
|
69
|
+
|
50
70
|
|
51
71
|
@dataclass(eq=True, frozen=True)
|
52
72
|
class QuantizationConfig:
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20250122
|
3
|
+
Version: 0.3.0.dev20250124
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -1,8 +1,8 @@
|
|
1
1
|
ai_edge_torch/__init__.py,sha256=8sPR_5uXJA4NEE0nIwNdSl-ADOJEoR8hAgYvBQDY70Y,1208
|
2
|
-
ai_edge_torch/_config.py,sha256=
|
2
|
+
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=xloGd_dX0MD8k-quT07WLlEN1zIGVtCKu6xBSvjofrc,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=mckvxznKLXdF2HuJg_IxQaT5Ty-iWl_iXElHEugH3VI,5452
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -16,11 +16,11 @@ ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=Z6E3U7SYZv
|
|
16
16
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=HCOkj0k3NhaYbtfjE8HDXVmYhZ9fL5V_u6VunVh9mN4,2116
|
17
17
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=UKC-wM93-oe8spxyFqgybJ0TwnSRw8f-SOA2glCh2FA,890
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/_decomp_registry.py,sha256=aWO_zHDF4j_hokoKJQNFIFmua4ysXztsgS6pcyBUht0,1082
|
19
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
|
19
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=7yEKSfXskXUk4tsd7c8vL155O-iU4eUjXCU5RSZHrbw,8204
|
20
20
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
|
21
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
|
21
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=OCFcPP618zH8IE12KTBQm2hRTtsaSeO3egvlOBUpNxA,13911
|
22
22
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=OhisegHY2j4cv_m9auCh9Mq9qmm1lUqpFLVO9X-oBlc,1032
|
23
|
-
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=
|
23
|
+
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=mr0MiLbaQmU-3S3KT-vb58kRWbNT3VJiCKY-K7_3tFg,10556
|
24
24
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
|
25
25
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/__init__.py,sha256=D8VX8SbCzfoyvPgMFHK7yxD7R-bzLxp2gfdKxgrWekA,742
|
26
26
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
|
@@ -128,9 +128,9 @@ ai_edge_torch/generative/layers/attention.py,sha256=GrAy8CT1pEsgRoB8JQP6PlnNYk8k
|
|
128
128
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
129
129
|
ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
|
130
130
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
131
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
131
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=sGGAZD0mWYuO4FukZfDbHXoxpBOBE9lTYICvZzDj5F8,6400
|
132
132
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
133
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
133
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=ZVRWEGw1BnLbLCuoR71kWGqQteKp-UM1YvMbbWYlkNw,7999
|
134
134
|
ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
|
135
135
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
136
136
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=vp8dVx6tOe99neJhpbrtIt5fvN5NFw19JVH1v0yi5Mg,4154
|
@@ -155,10 +155,11 @@ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=yzMgXkiZxHUF
|
|
155
155
|
ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
|
156
156
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
157
157
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
158
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
158
|
+
ai_edge_torch/generative/utilities/bmm_4d.py,sha256=2BMOYiFVUsl-bjxmLkrX4N7kpO0CnhB7eDYxm_iBCr8,2533
|
159
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=QIYxT-zATMzsD3LG-keRkxpJqDKXkbil4Se1KXthWFg,7726
|
159
160
|
ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
|
160
161
|
ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
|
161
|
-
ai_edge_torch/generative/utilities/model_builder.py,sha256=
|
162
|
+
ai_edge_torch/generative/utilities/model_builder.py,sha256=aXigoFEMLAKk7HQuWJM5ILs3igA4z2VH64ZCzCuBhDE,6671
|
162
163
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
163
164
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
164
165
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
@@ -197,7 +198,7 @@ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw5
|
|
197
198
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
198
199
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
199
200
|
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=VhmeGFnB5hrUsALiVWV96JJOqPDrTIWouHjTvLuT5eU,2477
|
200
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
201
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=fte81SZxgxeMcI3wWVKSTnUjIxVVilOJ6H3TybXyDmQ,11558
|
201
202
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
202
203
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
203
204
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
@@ -207,13 +208,13 @@ ai_edge_torch/odml_torch/lowerings/utils.py,sha256=pqM6mumpviFDHRaabp93CUAngzEZm
|
|
207
208
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
208
209
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
209
210
|
ai_edge_torch/quantize/pt2e_quantizer.py,sha256=CKIEhs9jCcna64qj1jFH9zEbMbRdyeGV_TmSqEBPjes,15741
|
210
|
-
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=
|
211
|
+
ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=nuO3w9gOj9sKcsTBBexVDw8UZnd06KsjNrFr_gyNaiA,36710
|
211
212
|
ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9PphCRdO8o,3172
|
212
213
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
213
214
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
214
215
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
215
|
-
ai_edge_torch_nightly-0.3.0.dev20250122.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
216
|
-
ai_edge_torch_nightly-0.3.0.dev20250122.dist-info/METADATA,sha256=
|
217
|
-
ai_edge_torch_nightly-0.3.0.dev20250122.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
218
|
-
ai_edge_torch_nightly-0.3.0.dev20250122.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
219
|
-
ai_edge_torch_nightly-0.3.0.dev20250122.dist-info/RECORD,,
|
216
|
+
ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
217
|
+
ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/METADATA,sha256=WKVqBXJtXvMv3JfqtYKcl1GFgKrKtSbZ-tJAol5PPHk,1966
|
218
|
+
ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
219
|
+
ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
220
|
+
ai_edge_torch_nightly-0.3.0.dev20250124.dist-info/RECORD,,
|
File without changes
|
File without changes
|