tico 0.1.0.dev250714__py3-none-any.whl → 0.1.0.dev251102__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tico/__init__.py +9 -1
- tico/config/base.py +1 -1
- tico/config/v1.py +5 -0
- tico/passes/cast_aten_where_arg_type.py +1 -1
- tico/passes/cast_clamp_mixed_type_args.py +169 -0
- tico/passes/cast_mixed_type_args.py +4 -2
- tico/passes/const_prop_pass.py +1 -1
- tico/passes/convert_conv1d_to_conv2d.py +1 -1
- tico/passes/convert_expand_to_slice_cat.py +153 -0
- tico/passes/convert_matmul_to_linear.py +312 -0
- tico/passes/convert_to_relu6.py +1 -1
- tico/passes/decompose_addmm.py +0 -3
- tico/passes/decompose_batch_norm.py +2 -2
- tico/passes/decompose_fake_quantize.py +0 -3
- tico/passes/decompose_fake_quantize_tensor_qparams.py +5 -6
- tico/passes/decompose_group_norm.py +0 -3
- tico/passes/legalize_predefined_layout_operators.py +2 -11
- tico/passes/lower_to_resize_nearest_neighbor.py +1 -1
- tico/passes/lower_to_slice.py +1 -1
- tico/passes/merge_consecutive_cat.py +1 -1
- tico/passes/ops.py +1 -1
- tico/passes/remove_redundant_assert_nodes.py +3 -1
- tico/passes/remove_redundant_expand.py +3 -6
- tico/passes/remove_redundant_reshape.py +5 -5
- tico/passes/segment_index_select.py +1 -1
- tico/quantization/__init__.py +6 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/gptq.py +1 -1
- tico/quantization/algorithm/gptq/quantizer.py +292 -0
- tico/{experimental/quantization → quantization}/algorithm/gptq/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/annotator.py +7 -14
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/adaptive_avg_pool2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/add.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/conv2d.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/div.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/linear.py +5 -7
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mean.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/mul.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/relu6.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/rsqrt.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/sub.py +4 -6
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/spec.py +1 -3
- tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/utils.py +1 -1
- tico/{experimental/quantization → quantization}/algorithm/pt2e/quantizer.py +5 -2
- tico/{experimental/quantization → quantization}/algorithm/pt2e/utils.py +1 -4
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/observer.py +26 -8
- tico/{experimental/quantization → quantization}/algorithm/smoothquant/quantizer.py +28 -9
- tico/quantization/algorithm/smoothquant/smooth_quant.py +327 -0
- tico/quantization/config/base.py +26 -0
- tico/quantization/config/gptq.py +29 -0
- tico/quantization/config/pt2e.py +25 -0
- tico/quantization/config/ptq.py +119 -0
- tico/{experimental/quantization/config.py → quantization/config/smoothquant.py} +9 -36
- tico/{experimental/quantization → quantization}/evaluation/evaluate.py +8 -17
- tico/{experimental/quantization → quantization}/evaluation/executor/circle_executor.py +3 -4
- tico/{experimental/quantization → quantization}/evaluation/executor/triv24_executor.py +2 -4
- tico/quantization/evaluation/metric.py +146 -0
- tico/{experimental/quantization → quantization}/evaluation/utils.py +1 -1
- tico/quantization/passes/__init__.py +1 -0
- tico/{experimental/quantization → quantization}/passes/fold_quant_ops.py +0 -1
- tico/quantization/passes/insert_quantize_on_dtype_mismatch.py +459 -0
- tico/{experimental/quantization → quantization}/passes/quantize_bias.py +0 -1
- tico/{experimental/quantization → quantization}/passes/remove_weight_dequant_op.py +1 -1
- tico/{experimental/quantization → quantization}/public_interface.py +19 -18
- tico/{experimental/quantization → quantization}/quantizer.py +1 -1
- tico/quantization/quantizer_registry.py +73 -0
- tico/quantization/wrapq/__init__.py +1 -0
- tico/quantization/wrapq/dtypes.py +70 -0
- tico/quantization/wrapq/examples/__init__.py +1 -0
- tico/quantization/wrapq/examples/compare_ppl.py +230 -0
- tico/quantization/wrapq/examples/debug_quant_outputs.py +224 -0
- tico/quantization/wrapq/examples/quantize_linear.py +107 -0
- tico/quantization/wrapq/examples/quantize_llama_attn.py +101 -0
- tico/quantization/wrapq/examples/quantize_llama_decoder_layer.py +125 -0
- tico/quantization/wrapq/examples/quantize_llama_mlp.py +95 -0
- tico/quantization/wrapq/examples/quantize_with_gptq.py +265 -0
- tico/quantization/wrapq/mode.py +32 -0
- tico/quantization/wrapq/observers/__init__.py +1 -0
- tico/quantization/wrapq/observers/affine_base.py +128 -0
- tico/quantization/wrapq/observers/base.py +98 -0
- tico/quantization/wrapq/observers/ema.py +62 -0
- tico/quantization/wrapq/observers/identity.py +74 -0
- tico/quantization/wrapq/observers/minmax.py +39 -0
- tico/quantization/wrapq/observers/mx.py +60 -0
- tico/quantization/wrapq/qscheme.py +40 -0
- tico/quantization/wrapq/quantizer.py +179 -0
- tico/quantization/wrapq/utils/__init__.py +1 -0
- tico/quantization/wrapq/utils/introspection.py +167 -0
- tico/quantization/wrapq/utils/metrics.py +124 -0
- tico/quantization/wrapq/utils/reduce_utils.py +25 -0
- tico/quantization/wrapq/wrappers/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/fairseq/__init__.py +5 -0
- tico/quantization/wrapq/wrappers/fairseq/decoder_export_single_step.py +234 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder.py +429 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_decoder_layer.py +492 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder.py +331 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_encoder_layer.py +163 -0
- tico/quantization/wrapq/wrappers/fairseq/quant_mha.py +381 -0
- tico/quantization/wrapq/wrappers/llama/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/llama/quant_attn.py +276 -0
- tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +176 -0
- tico/quantization/wrapq/wrappers/llama/quant_mlp.py +96 -0
- tico/quantization/wrapq/wrappers/nn/__init__.py +1 -0
- tico/quantization/wrapq/wrappers/nn/quant_layernorm.py +183 -0
- tico/quantization/wrapq/wrappers/nn/quant_linear.py +65 -0
- tico/quantization/wrapq/wrappers/nn/quant_silu.py +59 -0
- tico/quantization/wrapq/wrappers/ptq_wrapper.py +69 -0
- tico/quantization/wrapq/wrappers/quant_elementwise.py +111 -0
- tico/quantization/wrapq/wrappers/quant_module_base.py +168 -0
- tico/quantization/wrapq/wrappers/registry.py +125 -0
- tico/serialize/circle_graph.py +12 -4
- tico/serialize/circle_mapping.py +76 -2
- tico/serialize/circle_serializer.py +253 -148
- tico/serialize/operators/adapters/__init__.py +1 -0
- tico/serialize/operators/adapters/llama_rmsnorm.py +35 -0
- tico/serialize/operators/op_any.py +7 -14
- tico/serialize/operators/op_avg_pool2d.py +11 -4
- tico/serialize/operators/op_clamp.py +5 -7
- tico/serialize/operators/op_constant_pad_nd.py +41 -11
- tico/serialize/operators/op_conv2d.py +14 -6
- tico/serialize/operators/op_copy.py +26 -3
- tico/serialize/operators/op_cumsum.py +3 -1
- tico/serialize/operators/op_depthwise_conv2d.py +17 -7
- tico/serialize/operators/op_full_like.py +0 -2
- tico/serialize/operators/op_index_select.py +8 -1
- tico/serialize/operators/op_instance_norm.py +0 -6
- tico/serialize/operators/op_le.py +54 -0
- tico/serialize/operators/op_log1p.py +3 -2
- tico/serialize/operators/op_max_pool2d_with_indices.py +17 -7
- tico/serialize/operators/op_mm.py +15 -131
- tico/serialize/operators/op_mul.py +2 -8
- tico/serialize/operators/op_pow.py +3 -1
- tico/serialize/operators/op_repeat.py +12 -3
- tico/serialize/operators/op_reshape.py +1 -1
- tico/serialize/operators/op_rmsnorm.py +65 -0
- tico/serialize/operators/op_softmax.py +7 -14
- tico/serialize/operators/op_split_with_sizes.py +16 -8
- tico/serialize/operators/op_transpose_conv.py +11 -8
- tico/serialize/operators/op_view.py +2 -1
- tico/serialize/quant_param.py +5 -5
- tico/utils/convert.py +30 -17
- tico/utils/dtype.py +42 -0
- tico/utils/graph.py +1 -1
- tico/utils/model.py +2 -1
- tico/utils/padding.py +2 -2
- tico/utils/pytree_utils.py +134 -0
- tico/utils/record_input.py +102 -0
- tico/utils/register_custom_op.py +29 -4
- tico/utils/serialize.py +16 -3
- tico/utils/signature.py +247 -0
- tico/utils/torch_compat.py +52 -0
- tico/utils/utils.py +50 -58
- tico/utils/validate_args_kwargs.py +38 -3
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/METADATA +49 -2
- tico-0.1.0.dev251102.dist-info/RECORD +271 -0
- tico/experimental/quantization/__init__.py +0 -1
- tico/experimental/quantization/algorithm/gptq/quantizer.py +0 -225
- tico/experimental/quantization/algorithm/smoothquant/smooth_quant.py +0 -164
- tico/experimental/quantization/evaluation/metric.py +0 -109
- tico/experimental/quantization/passes/insert_quantize_on_dtype_mismatch.py +0 -437
- tico-0.1.0.dev250714.dist-info/RECORD +0 -209
- /tico/{experimental/quantization → quantization}/algorithm/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/gptq/quant.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/config.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/annotation/op/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/pt2e/transformation/convert_scalars_to_attrs.py +0 -0
- /tico/{experimental/quantization → quantization}/algorithm/smoothquant/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation → quantization/config}/__init__.py +0 -0
- /tico/{experimental/quantization/evaluation/executor → quantization/evaluation}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/backend.py +0 -0
- /tico/{experimental/quantization/passes → quantization/evaluation/executor}/__init__.py +0 -0
- /tico/{experimental/quantization → quantization}/evaluation/executor/backend_executor.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_backward.py +0 -0
- /tico/{experimental/quantization → quantization}/passes/propagate_qparam_forward.py +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/LICENSE +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/WHEEL +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/entry_points.txt +0 -0
- {tico-0.1.0.dev250714.dist-info → tico-0.1.0.dev251102.dist-info}/top_level.txt +0 -0
tico/serialize/operators/op_repeat.py
CHANGED

@@ -21,7 +21,11 @@ import torch
 from circle_schema import circle
 
 from tico.serialize.circle_graph import CircleSubgraph
-from tico.serialize.circle_mapping import
+from tico.serialize.circle_mapping import (
+    extract_circle_dtype,
+    extract_shape,
+    to_circle_shape,
+)
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index

@@ -70,11 +74,16 @@ class RepeatVisitor(NodeVisitor):
             if r > 1:
                 # Except last created concat, a tensor should be created.
                 if repeat_dim_cnt > 1:
-                    repeated_shape = list(tensor_shape)
+                    repeated_shape: List[int | torch.SymInt] = list(tensor_shape)
                     repeated_shape[idx] = repeated_shape[idx] * r
+
+                    repeated_cshape, repeated_cshape_signature = to_circle_shape(
+                        repeated_shape
+                    )
                     concat_output = self.graph.add_tensor_from_scratch(
                         prefix=f"{node.name}_concat_{idx}",
-                        shape=
+                        shape=repeated_cshape,
+                        shape_signature=repeated_cshape_signature,
                         dtype=tensor_dtype,
                         source_node=node,
                     )
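This `to_circle_shape` pattern recurs throughout the serializer changes below: a possibly-symbolic shape is split into a static `shape` plus a `shape_signature`, mirroring the TFLite/Circle convention in which dynamic dimensions are marked with -1 in the signature. A minimal sketch of the assumed behavior (illustrative only, not tico's actual implementation):

import torch
from typing import List, Optional, Tuple

def to_circle_shape_sketch(
    shape: List[int | torch.SymInt],
) -> Tuple[List[int], Optional[List[int]]]:
    # Fully static shapes need no signature.
    if all(isinstance(d, int) for d in shape):
        return list(shape), None
    # Dynamic dims: placeholder 1 in `shape`, -1 marker in `shape_signature`.
    cshape = [1 if isinstance(d, torch.SymInt) else d for d in shape]
    signature = [-1 if isinstance(d, torch.SymInt) else d for d in shape]
    return cshape, signature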
tico/serialize/operators/op_rmsnorm.py
ADDED

@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import torch._ops
+    import torch.fx
+import torch
+from circle_schema import circle
+
+from tico.serialize.circle_graph import CircleSubgraph
+from tico.serialize.operators.hashable_opcode import OpCode
+from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
+from tico.serialize.operators.utils import create_builtin_operator, get_op_index
+from tico.utils.validate_args_kwargs import CircleRMSNormArgs
+
+
+@register_node_visitor
+class RMSNormVisitor(NodeVisitor):
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.circle_custom.rms_norm.default,
+    ]
+
+    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
+        super().__init__(op_codes, graph)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+    ) -> circle.Operator.OperatorT:
+        args = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]
+        input = args.input
+        weight = args.weight
+        eps = args.eps
+
+        op_index = get_op_index(
+            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
+        )
+
+        inputs = [input, weight]
+        outputs = [node]
+        operator = create_builtin_operator(self.graph, op_index, inputs, outputs)
+
+        # Op-specific option
+        operator.builtinOptionsType = (
+            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
+        )
+        option = circle.RmsNormOptions.RmsNormOptionsT()
+        option.epsilon = eps
+
+        operator.builtinOptions = option
+
+        return operator
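For context, the computation a Circle RMS_NORM operator presumably encodes is standard RMSNorm: normalize by the root mean square over the last dimension, then scale by a learned weight. A reference sketch in plain PyTorch, for illustration only:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

Note that the visitor above stores only `epsilon` in RmsNormOptions; the weight travels as a second operator input.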
tico/serialize/operators/op_softmax.py
CHANGED

@@ -24,25 +24,18 @@ from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
 from tico.serialize.operators.utils import create_builtin_operator, get_op_index
 from tico.utils.errors import NotYetSupportedError
-from tico.utils.utils import HAS_TORCH_OVER_25
 from tico.utils.validate_args_kwargs import SafeSoftmaxArgs, SoftmaxArgs
 
 
 @register_node_visitor
 class SoftMaxVisitor(NodeVisitor):
-    target: List[torch._ops.OpOverload] = (
-
-
-
-
-
-
-    ]
-    if HAS_TORCH_OVER_25
-    else [
-        torch.ops.aten._softmax.default,
-    ]
-    )
+    target: List[torch._ops.OpOverload] = [
+        torch.ops.aten._softmax.default,
+        # NOTE: Let's treat _safe_softmax as normal _softmax as its usage is for training.
+        # In order for optimization during inference, it can be replaced to softmax.
+        # ref: https://github.com/pytorch/pytorch/pull/133882
+        torch.ops.aten._safe_softmax.default,
+    ]
 
     def __init__(self, op_codes: Dict[OpCode, int], graph):
         super().__init__(op_codes, graph)
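Per the NOTE above and the linked PR, `aten._safe_softmax` is a training-oriented variant used internally by PyTorch, so exported graphs can contain it in place of `_softmax`. A quick way to see which overload a given export produces (toy module of mine, not from the package):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)

ep = torch.export.export(M(), (torch.randn(2, 8),))
# Inspect which softmax overload the exported graph contains; SoftMaxVisitor
# now serializes _softmax and _safe_softmax alike.
print(ep.graph)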
tico/serialize/operators/op_split_with_sizes.py
CHANGED

@@ -58,12 +58,14 @@ class SplitWithSizesVisitor(NodeVisitor):
         inputs = [input, split_sizes_i32, axis_i32]
 
         """
-        `split_with_sizes` has multiple output tensors
-
+        `split_with_sizes` has multiple output tensors along with `getitem`.
+        Unlike other ops, node itself doesn't become a circle tensor. Instead, each `getitem` will be
         a circle tensor.
-
-
-
+
+        torch module having `split_with_sizes` may return selected outputs by using `getitem`.
+        However, one-compiler assumes that `CircleSplitV` always have all outputs.
+
+        So, let's add unused output tensors to compensate this restriction.
         """
         outputs: List[Union[circle.Tensor.TensorT, torch.fx.node.Node]] = []
         sorted_users = sorted(node.users.keys(), key=lambda x: x.args[1])  # type: ignore[arg-type, return-value]

@@ -80,11 +82,17 @@ class SplitWithSizesVisitor(NodeVisitor):
             fake_tensor = node_val[idx]
             assert isinstance(fake_tensor, FakeTensor)
             shape = list(fake_tensor.size())
+
+            if any(isinstance(s, torch.SymInt) for s in shape):
+                # TODO: support dynamic shape
+                raise NotImplementedError("Dynamic shape is not supported yet.")
+
             dtype = to_circle_dtype(fake_tensor.dtype)
             tensor = self.graph.add_tensor_from_scratch(
-                f"{node.name}_unused_{idx}",
-                shape,
-
+                prefix=f"{node.name}_unused_{idx}",
+                shape=shape,
+                shape_signature=None,  # TODO: support dynamic shape
+                dtype=dtype,
                 source_node=node,
             )
             outputs.append(tensor)
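The docstring's scenario in miniature (a toy module of mine, not from the package): only one split output is consumed, so after export and dead-code elimination a single `getitem` survives, and the visitor must invent placeholder tensors for the other CircleSplitV outputs.

import torch

class TakeMiddle(torch.nn.Module):
    def forward(self, x):
        a, b, c = torch.split(x, [2, 3, 3], dim=-1)  # lowers to split_with_sizes
        return b  # a and c are unused; their getitem nodes disappear

ep = torch.export.export(TakeMiddle(), (torch.randn(1, 8),))
print(ep.graph)  # one split_with_sizes node, one surviving getitem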
tico/serialize/operators/op_transpose_conv.py
CHANGED

@@ -23,7 +23,8 @@ from circle_schema import circle
 from tico.serialize.circle_mapping import (
     circle_legalize_dtype_to,
     extract_circle_dtype,
-
+    extract_circle_shape,
+    to_circle_shape,
 )
 from tico.serialize.operators.hashable_opcode import OpCode
 from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor

@@ -76,15 +77,13 @@ class TransposeConvVisitor(NodeVisitor):
         bias = args.bias
         stride = args.stride
         padding = args.padding
-        output_padding = args.output_padding
         groups = args.groups
-        dilation = args.dilation
 
         assert groups == 1, "Only support group 1"
 
-        input_shape =
-        output_shape =
-        weight_shape =
+        input_shape, input_shape_signature = extract_circle_shape(input_)
+        output_shape, _ = extract_circle_shape(node)
+        weight_shape, _ = extract_circle_shape(weight)
         assert len(input_shape) == 4, len(input_shape)
         assert len(output_shape) == 4, len(output_shape)
         assert len(weight_shape) == 4, len(weight_shape)

@@ -103,17 +102,21 @@ class TransposeConvVisitor(NodeVisitor):
             ],
             dtype=torch.int32,
         )
-        pad_output_shape = [
+        pad_output_shape: List[int | torch.SymInt] = [
             input_shape[0],
             input_shape[1] + pad_h * 2,
             input_shape[2] + pad_w * 2,
             input_shape[3],
         ]
+        pad_output_cshape, pad_output_cshape_signature = to_circle_shape(
+            pad_output_shape
+        )
         # create padded output tensor
         input_qparam: Optional[QuantParam] = input_.meta.get(QPARAM_KEY)
         pad_output = self.graph.add_tensor_from_scratch(
             prefix=f"{node.name}_input_pad_output",
-            shape=
+            shape=pad_output_cshape,
+            shape_signature=pad_output_cshape_signature,
             dtype=extract_circle_dtype(input_),
             qparam=input_qparam,
             source_node=node,
tico/serialize/operators/op_view.py
CHANGED

@@ -56,6 +56,7 @@ class ViewVisitor(NodeVisitor):
         if isinstance(size, int):
             raise Exception("scalar size conversion is not supported yet.")
 
+        # TODO: support dynamic shape
         size_i32 = circle_legalize_dtype_to(size, dtype=torch.int32)
         inputs = [input, size_i32]
         outputs = [node]

@@ -67,7 +68,7 @@ class ViewVisitor(NodeVisitor):
             circle.BuiltinOptions.BuiltinOptions.ReshapeOptions
         )
         option = circle.ReshapeOptions.ReshapeOptionsT()
-        option.newShape = size_i32
+        option.newShape = size_i32.tolist()
 
         operator.builtinOptions = option
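The `newShape` fix converts the value to a plain Python list before storing it on the flatbuffer options object, which holds native ints rather than a tensor. A tiny illustration (assuming `circle_legalize_dtype_to` yields a torch.Tensor, which the `.tolist()` call implies):

import torch

size_i32 = torch.tensor([2, 3, 4], dtype=torch.int32)
size_i32.tolist()  # [2, 3, 4] -- plain ints, serializable into ReshapeOptionsT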
tico/serialize/quant_param.py
CHANGED
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
 """
 This is a key for torch.fx.Node's meta dict to save QuantParam
 

@@ -19,11 +24,6 @@ QuantParam can be retrieved as node.meta[QPARAM_KEY]
 """
 QPARAM_KEY = "_quantization_parameters_"
 
-from dataclasses import dataclass
-from typing import List, Optional
-
-import torch
-
 
 @dataclass
 class QuantParam:
tico/utils/convert.py
CHANGED
@@ -20,25 +20,14 @@ import torch
 from torch.export import export, ExportedProgram
 
 from tico.config import CompileConfigBase, get_default_config
-from tico.experimental.quantization.passes.fold_quant_ops import FoldQuantOps
-from tico.experimental.quantization.passes.insert_quantize_on_dtype_mismatch import (
-    InsertQuantizeOnDtypeMismatch,
-)
-from tico.experimental.quantization.passes.propagate_qparam_backward import (
-    PropagateQParamBackward,
-)
-from tico.experimental.quantization.passes.propagate_qparam_forward import (
-    PropagateQParamForward,
-)
-from tico.experimental.quantization.passes.quantize_bias import QuantizeBias
-from tico.experimental.quantization.passes.remove_weight_dequant_op import (
-    RemoveWeightDequantOp,
-)
 from tico.passes.cast_aten_where_arg_type import CastATenWhereArgType
+from tico.passes.cast_clamp_mixed_type_args import CastClampMixedTypeArgs
 from tico.passes.cast_mixed_type_args import CastMixedTypeArgs
 from tico.passes.const_prop_pass import ConstPropPass
 from tico.passes.convert_conv1d_to_conv2d import ConvertConv1dToConv2d
+from tico.passes.convert_expand_to_slice_cat import ConvertExpandToSliceCat
 from tico.passes.convert_layout_op_to_reshape import ConvertLayoutOpToReshape
+from tico.passes.convert_matmul_to_linear import ConvertMatmulToLinear
 from tico.passes.convert_repeat_to_expand_copy import ConvertRepeatToExpandCopy
 from tico.passes.convert_to_relu6 import ConvertToReLU6
 from tico.passes.decompose_addmm import DecomposeAddmm

@@ -71,6 +60,14 @@ from tico.passes.remove_redundant_slice import RemoveRedundantSlice
 from tico.passes.remove_redundant_to_copy import RemoveRedundantToCopy
 from tico.passes.restore_linear import RestoreLinear
 from tico.passes.segment_index_select import SegmentIndexSelectConst
+from tico.quantization.passes.fold_quant_ops import FoldQuantOps
+from tico.quantization.passes.insert_quantize_on_dtype_mismatch import (
+    InsertQuantizeOnDtypeMismatch,
+)
+from tico.quantization.passes.propagate_qparam_backward import PropagateQParamBackward
+from tico.quantization.passes.propagate_qparam_forward import PropagateQParamForward
+from tico.quantization.passes.quantize_bias import QuantizeBias
+from tico.quantization.passes.remove_weight_dequant_op import RemoveWeightDequantOp
 from tico.serialize.circle_serializer import build_circle
 from tico.serialize.operators.node_visitor import get_support_targets
 from tico.utils import logging

@@ -105,6 +102,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten._safe_softmax.default,
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
         torch.ops.aten.linear.default,
+        torch.ops.aten.upsample_nearest2d.vec,
     )
     ep = ep.run_decompositions(_preserve_ops=_preserve_ops)

@@ -123,6 +121,7 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.ops.aten.relu6.default,  # Do not decompose to hardtanh
         torch.ops.aten.prelu.default,
         torch.ops.aten.linear.default,
+        torch.ops.aten.upsample_nearest2d.vec,
     )
     for op in _preserve_ops:
         if op in _decomp_table:

@@ -137,6 +136,8 @@ def traced_run_decompositions(exported_program: ExportedProgram):
         torch.__version__.startswith("2.6")
         or torch.__version__.startswith("2.7")
         or torch.__version__.startswith("2.8")
+        or torch.__version__.startswith("2.9")
+        or torch.__version__.startswith("2.10")
     ):
         return run_decompositions(exported_program)
     else:

@@ -153,7 +154,7 @@ def check_unsupported_target(exported_program: ExportedProgram):
     for n in exported_program.graph.nodes:
         if n.op != "call_function":
             continue
-        if
+        if n.target not in supported_target:
             unsupported.append(n)
 
     if unsupported:

@@ -245,12 +246,21 @@ def convert_exported_module_to_circle(
             ConstPropPass(),
             SegmentIndexSelectConst(),
             LegalizeCausalMaskValue(enabled=config.get("legalize_causal_mask_value")),
+            ConvertExpandToSliceCat(enabled=config.get("convert_expand_to_slice_cat")),
+            ConvertMatmulToLinear(
+                enable_lhs_const=config.get("convert_lhs_const_mm_to_fc"),
+                enable_rhs_const=config.get("convert_rhs_const_mm_to_fc"),
+                enable_single_batch_lhs_const_bmm=config.get(
+                    "convert_single_batch_lhs_const_bmm_to_fc"
+                ),
+            ),
             LowerToResizeNearestNeighbor(),
             LegalizePreDefinedLayoutOperators(),
             LowerPow2ToMul(),
             ConvertConv1dToConv2d(),
             *LowerToSlicePasses(),
             FuseLeadingUnsqueezeReshape(),
+            CastClampMixedTypeArgs(),
         ]
     )
     circle_legalize.run(exported_program)

@@ -282,7 +292,7 @@ def convert_exported_module_to_circle(
 
     check_unsupported_target(exported_program)
     check_training_ops(exported_program)
-    circle_program = build_circle(exported_program)
+    circle_program = build_circle(exported_program, config)
 
     return circle_program

@@ -291,6 +301,7 @@ def convert(
     mod: torch.nn.Module,
    args: Tuple[Any, ...],
     kwargs: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict] = None,
     strict: bool = True,
     config: CompileConfigBase = get_default_config(),
 ) -> CircleModel:

@@ -301,7 +312,9 @@ def convert(
     )
 
     with torch.no_grad():
-        exported_program = export(
+        exported_program = export(
+            mod, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
+        )
 
     circle_binary = convert_exported_module_to_circle(exported_program, config=config)
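The new `dynamic_shapes` parameter is forwarded straight to `torch.export.export`, so it takes the standard torch.export format. A minimal sketch of an export with a symbolic batch dimension (model, dim name, and sizes are illustrative):

import torch
import tico

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

batch = torch.export.Dim("batch", max=32)  # symbolic batch dimension
cm = tico.convert(M(), (torch.randn(4, 16),), dynamic_shapes={"x": {0: batch}})
cm.save("relu.circle")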
tico/utils/dtype.py
ADDED
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+from circle_schema import circle
+
+NUMPY_TO_TORCH_DTYPE_DICT = {
+    np.dtype("float32"): torch.float32,
+    np.dtype("float64"): torch.float64,
+    np.dtype("float16"): torch.float16,
+    np.dtype("complex64"): torch.complex64,
+    np.dtype("complex128"): torch.complex128,
+    np.dtype("int64"): torch.int64,
+    np.dtype("int32"): torch.int32,
+    np.dtype("int16"): torch.int16,
+    np.dtype("int8"): torch.int8,
+    np.dtype("uint8"): torch.uint8,
+    np.dtype("bool"): torch.bool,
+}
+
+CIRCLE_TO_TORCH_DTYPE_DICT = {
+    circle.TensorType.TensorType.FLOAT32: torch.float32,
+    circle.TensorType.TensorType.UINT8: torch.uint8,
+    circle.TensorType.TensorType.INT8: torch.int8,
+    circle.TensorType.TensorType.INT16: torch.int16,
+    circle.TensorType.TensorType.INT32: torch.int32,
+    circle.TensorType.TensorType.INT64: torch.int64,
+    circle.TensorType.TensorType.BOOL: torch.bool,
+}
+
+
+def numpy_dtype_to_torch_dtype(np_dtype: np.dtype) -> torch.dtype:
+    return NUMPY_TO_TORCH_DTYPE_DICT[np_dtype]
+
+
+def circle_dtype_to_torch_dtype(circle_dtype: int) -> torch.dtype:
+    assert isinstance(circle_dtype, int)
+    if circle_dtype not in CIRCLE_TO_TORCH_DTYPE_DICT:
+        raise RuntimeError(f"Unsupported dtype {circle_dtype}")
+
+    torch_dtype = CIRCLE_TO_TORCH_DTYPE_DICT[circle_dtype]
+    assert torch_dtype is not None
+    return torch_dtype
tico/utils/graph.py
CHANGED
@@ -24,7 +24,7 @@ import torch
 from torch.export import ExportedProgram
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
-from tico.utils.utils import get_fake_mode
+from tico.utils.utils import get_fake_mode
 
 
 def is_torch_param(node: torch.fx.Node, ep: ExportedProgram):
tico/utils/model.py
CHANGED
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+from pathlib import Path
 from typing import Any
 
 from tico.interpreter import infer

@@ -32,6 +33,6 @@ class CircleModel:
         buf = bytes(f.read())
         return CircleModel(buf)
 
-    def save(self, circle_path: str) -> None:
+    def save(self, circle_path: str | Path) -> None:
         with open(circle_path, "wb") as f:
             f.write(self.circle_binary)
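With the widened signature, pathlib paths pass through without casting; a tiny sketch (continuing from any CircleModel `cm`, e.g. the one `tico.convert` returns):

from pathlib import Path

out = Path("artifacts") / "model.circle"
out.parent.mkdir(parents=True, exist_ok=True)
cm.save(out)  # str paths still work as before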
tico/utils/padding.py
CHANGED
@@ -39,8 +39,8 @@ class ConvPaddingInfo(NamedTuple):
 
 def identify_padding(
     padding: PaddingValue,
-    input_shape: Sequence[int],
-    output_shape: Sequence[int],
+    input_shape: Sequence[int | torch.SymInt] | torch.Size,
+    output_shape: Sequence[int | torch.SymInt] | torch.Size,
     stride: Sequence[int],
 ) -> ConvPaddingInfo:
     """
tico/utils/pytree_utils.py
ADDED

@@ -0,0 +1,134 @@
+import threading
+
+import torch
+from packaging.version import Version
+
+from tico.utils import logging
+from tico.utils.installed_packages import is_transformers_installed
+
+__all__ = ["register_dynamic_cache"]
+
+
+def register_dynamic_cache():
+    PyTreeRegistryHelper().register_dynamic_cache()
+
+
+class PyTreeRegistryHelper:
+    """
+    Thread-safe singleton helper class for registering custom PyTree nodes.
+
+    This class provides functionality to register DynamicCache as a PyTree node
+    for torch.export compatibility. This registration is only needed for
+    transformers versions below 4.50.0.
+
+    Thread Safety:
+        - Uses a class-level threading.Lock() to ensure thread-safe singleton instantiation
+        - Uses the same lock to protect the registration process from concurrent calls
+    """
+
+    _instance = None  # Class variable to hold the singleton instance
+    _has_called = False  # Flag to track if registration has been performed
+    _lock = threading.Lock()  # Class-level lock for thread-safe operations
+
+    def __init__(self):
+        """Private constructor to prevent direct instantiation"""
+        pass
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Thread-safe singleton instance creation using double-checked locking pattern.
+
+        Returns:
+            PyTreeRegistryHelper: The singleton instance of this class
+        """
+        if not cls._instance:
+            with cls._lock:  # Acquire lock for thread-safe instantiation
+                if not cls._instance:  # Double-check after acquiring lock
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def register_dynamic_cache(self):
+        """
+        Registers DynamicCache as a PyTree node for torch.export compatibility.
+
+        This method is thread-safe and idempotent - it will only perform the
+        registration once, even if called multiple times from different threads.
+
+        Note:
+            This registration is only needed for transformers versions below 4.50.0.
+
+        Raises:
+            ImportError: If transformers package is not installed
+        """
+        with self._lock:  # Acquire lock for thread-safe registration
+            if self.__class__._has_called:
+                logger = logging.getLogger(__name__)
+                logger.debug("register_dynamic_cache already called, skipping")
+                return
+
+            self.__class__._has_called = True
+            logger = logging.getLogger(__name__)
+            logger.info("Registering DynamicCache PyTree node")
+
+            if not is_transformers_installed:  # type: ignore[truthy-function]
+                raise ImportError("transformers package is not installed")
+
+            import transformers
+
+            HAS_TRANSFORMERS_LESS_4_50_0 = Version(transformers.__version__) < Version(
+                "4.50.0"
+            )
+            if not HAS_TRANSFORMERS_LESS_4_50_0:
+                return
+
+            from transformers.cache_utils import DynamicCache
+
+            def _flatten_dynamic_cache(dynamic_cache: DynamicCache):
+                if not isinstance(dynamic_cache, DynamicCache):
+                    raise RuntimeError(
+                        "This pytree flattening function should only be applied to DynamicCache"
+                    )
+                HAS_TORCH_2_6_0 = Version(torch.__version__) >= Version("2.6.0")
+                if not HAS_TORCH_2_6_0:
+                    logger = logging.getLogger(__name__)
+                    logger.warning_once(
+                        "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
+                    )
+                dictionary = {
+                    "key_cache": getattr(dynamic_cache, "key_cache"),
+                    "value_cache": getattr(dynamic_cache, "value_cache"),
+                }
+                return torch.utils._pytree._dict_flatten(dictionary)
+
+            def _flatten_with_keys_dynamic_cache(dynamic_cache: DynamicCache):
+                dictionary = {
+                    "key_cache": getattr(dynamic_cache, "key_cache"),
+                    "value_cache": getattr(dynamic_cache, "value_cache"),
+                }
+                return torch.utils._pytree._dict_flatten_with_keys(dictionary)
+
+            def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
+                dictionary = torch.utils._pytree._dict_unflatten(values, context)
+                cache = DynamicCache()
+                for k, v in dictionary.items():
+                    setattr(cache, k, v)
+                return cache
+
+            def _flatten_dynamic_cache_for_fx(cache, spec):
+                dictionary = {
+                    "key_cache": getattr(cache, "key_cache"),
+                    "value_cache": getattr(cache, "value_cache"),
+                }
+                return torch.fx._pytree._dict_flatten_spec(dictionary, spec)
+
+            torch.utils._pytree.register_pytree_node(
+                DynamicCache,
+                _flatten_dynamic_cache,
+                _unflatten_dynamic_cache,
+                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+                flatten_with_keys_fn=_flatten_with_keys_dynamic_cache,
+            )
+            # TODO: This won't be needed in torch 2.7+.
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, _flatten_dynamic_cache_for_fx
+            )
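A typical call site, sketched (model choice and export arguments are illustrative; per the docstring, the registration only matters on transformers < 4.50.0 and is a no-op otherwise):

import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

from tico.utils.pytree_utils import register_dynamic_cache

register_dynamic_cache()  # idempotent; safe to call more than once

model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
input_ids = torch.randint(0, model.config.vocab_size, (1, 8))
# With DynamicCache registered as a pytree node, torch.export can flatten
# the cache object passed through the forward signature.
ep = torch.export.export(
    model,
    (input_ids,),
    {"past_key_values": DynamicCache(), "use_cache": True},
    strict=False,
)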