tico 0.1.0.dev250810__py3-none-any.whl → 0.1.0.dev250811__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tico/__init__.py CHANGED
@@ -29,7 +29,7 @@ __all__ = [
29
29
  ]
30
30
 
31
31
  # THIS LINE IS AUTOMATICALLY GENERATED BY setup.py
32
- __version__ = "0.1.0.dev250810"
32
+ __version__ = "0.1.0.dev250811"
33
33
 
34
34
  MINIMUM_SUPPORTED_VERSION = "2.5.0"
35
35
  SECURE_TORCH_VERSION = "2.6.0"
@@ -0,0 +1 @@
1
+ # DO NOT REMOVE THIS FILE
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from contextlib import contextmanager
16
+
17
+ import torch
18
+
19
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
20
+
21
+
22
+ def llama_rmsnorm_forward_adapter(self: LlamaRMSNorm, hidden_states: torch.Tensor):
23
+ return torch.ops.circle_custom.rms_norm(
24
+ hidden_states, self.weight, self.variance_epsilon
25
+ )
26
+
27
+
28
+ @contextmanager
29
+ def patched_llama_rmsnorm():
30
+ orig = LlamaRMSNorm.forward
31
+ LlamaRMSNorm.forward = llama_rmsnorm_forward_adapter
32
+ try:
33
+ yield
34
+ finally:
35
+ LlamaRMSNorm.forward = orig
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict, List, TYPE_CHECKING
16
+
17
+ if TYPE_CHECKING:
18
+ import torch._ops
19
+ import torch.fx
20
+ import torch
21
+ from circle_schema import circle
22
+
23
+ from tico.serialize.circle_graph import CircleSubgraph
24
+ from tico.serialize.operators.hashable_opcode import OpCode
25
+ from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor
26
+ from tico.serialize.operators.utils import create_builtin_operator, get_op_index
27
+ from tico.utils.validate_args_kwargs import CircleRMSNormArgs
28
+
29
+
30
@register_node_visitor
class RMSNormVisitor(NodeVisitor):
    """Lower ``circle_custom.rms_norm`` to the Circle ``RMS_NORM`` builtin."""

    target: List[torch._ops.OpOverload] = [
        torch.ops.circle_custom.rms_norm.default,
    ]

    def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph):
        super().__init__(op_codes, graph)

    def define_node(
        self,
        node: torch.fx.Node,
    ) -> circle.Operator.OperatorT:
        """Emit one ``RMS_NORM`` operator for ``node``.

        The operator reads ``(input, weight)`` and writes ``node``. ``eps`` is
        carried in ``RmsNormOptions.epsilon``.
        """
        parsed = CircleRMSNormArgs(*node.args, **node.kwargs)  # type: ignore[arg-type]

        rms_norm_index = get_op_index(
            circle.BuiltinOperator.BuiltinOperator.RMS_NORM, self._op_codes
        )
        # NOTE(review): weight is Optional in the args dataclass; confirm that
        # create_builtin_operator accepts a None input if that case can occur.
        operator = create_builtin_operator(
            self.graph, rms_norm_index, [parsed.input, parsed.weight], [node]
        )

        # Attach the op-specific options table that carries epsilon.
        rms_options = circle.RmsNormOptions.RmsNormOptionsT()
        rms_options.epsilon = parsed.eps
        operator.builtinOptionsType = (
            circle.BuiltinOptions.BuiltinOptions.RmsNormOptions
        )
        operator.builtinOptions = rms_options

        return operator
@@ -703,6 +703,28 @@ def CircleQuantizeMX():
703
703
  return input_
704
704
 
705
705
 
706
+ def CircleRMSNorm():
707
+ @custom_op("circle_custom::rms_norm", mutates_args=())
708
+ def rms_norm(
709
+ hidden_states: torch.Tensor,
710
+ weight: Optional[torch.Tensor] = None,
711
+ eps: float = 1e-05,
712
+ ) -> torch.Tensor:
713
+ input_dtype = hidden_states.dtype
714
+ hidden_states = hidden_states.to(torch.float32)
715
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
716
+ hidden_states = hidden_states * torch.rsqrt(variance + eps)
717
+ return weight * hidden_states.to(input_dtype)
718
+
719
+ @register_fake("circle_custom::rms_norm")
720
+ def _(
721
+ hidden_states: torch.Tensor,
722
+ weight: Optional[torch.Tensor] = None,
723
+ eps: float = 1e-05,
724
+ ) -> torch.Tensor:
725
+ return hidden_states.new_empty(hidden_states.size())
726
+
727
+
706
728
  # Add custom ops to the torch namespace
707
729
  def RegisterOps():
708
730
  CircleResizeNearestNeighbor()
@@ -715,3 +737,4 @@ def RegisterOps():
715
737
  CircleAvgPool2D()
716
738
  CircleInstanceNorm()
717
739
  CircleQuantizeMX()
740
+ CircleRMSNorm()
@@ -171,6 +171,19 @@ class CatArgs:
171
171
  dim: int = 0
172
172
 
173
173
 
174
@enforce_type
@dataclass
class CircleRMSNormArgs:
    """
    Arguments of the RMSNorm custom op (a circle_custom op, not an aten op).
    circle_custom.rms_norm(Tensor hidden_states, Tensor? weight=None, float eps=1e-05) -> Tensor
    """

    input: torch.fx.Node
    # Defaults mirror the custom-op schema, so a node whose args omit the
    # defaulted trailing parameters can still be parsed.
    weight: Optional[torch.fx.Node] = None
    eps: Optional[float] = 1e-05
185
+
186
+
174
187
  @enforce_type
175
188
  @dataclass
176
189
  class ClampArgs:
@@ -931,6 +944,19 @@ class ResizeNearestNeighborArgs:
931
944
  size: List[int]
932
945
 
933
946
 
947
@enforce_type
@dataclass
class RMSNormArgs:
    """
    rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor
    """

    input: torch.fx.Node
    normalized_shape: List[int]
    # Defaults mirror the schema above, so a node that omits the optional
    # trailing arguments can still be parsed.
    weight: Optional[torch.fx.Node] = None
    eps: Optional[float] = None
958
+
959
+
934
960
  @enforce_type
935
961
  @dataclass
936
962
  class RoundArgs:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: tico
3
- Version: 0.1.0.dev250810
3
+ Version: 0.1.0.dev250811
4
4
  Summary: Convert exported Torch module to circle
5
5
  Home-page: UNKNOWN
6
6
  License: UNKNOWN
@@ -1,4 +1,4 @@
1
- tico/__init__.py,sha256=H67V6wHv1jk6smTy2N_nzF9vFR0U8Rf-Vsc19oWjhpM,1883
1
+ tico/__init__.py,sha256=i6HyVzF3572_JnWulVxE-KRy80Yd4W998rViuVqVHqg,1883
2
2
  tico/pt2_to_circle.py,sha256=gu3MD4Iqc0zMZcCZ2IT8oGbyj21CTSbT3Rgd9s2B_9A,2767
3
3
  tico/config/__init__.py,sha256=xZzCXjZ84qE-CsBi-dfaL05bqpQ3stKKfTXhnrJRyVs,142
4
4
  tico/config/base.py,sha256=q5xMqGxTUZs4mFqt5c7i_y9U00fYgdMGl9nUqIVMlCo,1248
@@ -166,6 +166,7 @@ tico/serialize/operators/op_relu6.py,sha256=ZWqEolfAKjOdUC1ZCg0iuu4dBhkJRxVYR2tU
166
166
  tico/serialize/operators/op_repeat.py,sha256=VrRxD31pT3hRGH-5n6ia3PJBXh_u0GvIl1hZZYFrKTQ,4507
167
167
  tico/serialize/operators/op_reshape.py,sha256=6wErQpmDX9mAmfJRCTg_cg1uOdJZqHm8Nux8dNI53Vg,2559
168
168
  tico/serialize/operators/op_resize_nearest_neighbor.py,sha256=dXaAnZ5M_ko_tH-HolxNpHFXkDUQ8x45myskojP5XZE,2771
169
+ tico/serialize/operators/op_rmsnorm.py,sha256=vkJgg2YtTY9pjceTLh6gTZ-MN3EltnlEyAP5gVc5SiU,2216
169
170
  tico/serialize/operators/op_round.py,sha256=pe6w_TB4xGLu0iPv4Qo0a0fIkY9DgCgXk5127TWt8pE,1837
170
171
  tico/serialize/operators/op_rsqrt.py,sha256=yl2vd8InjhLPbE0vHIrEera6DVXlY9dLgO7yZZCH3RI,1837
171
172
  tico/serialize/operators/op_scalar_tensor.py,sha256=vDWxi4hXwyDJJhvfMR_QrBInw_No3WeU_M4gtfZqmbo,1928
@@ -186,6 +187,8 @@ tico/serialize/operators/op_unsqueeze.py,sha256=ZHhfVXSWEiwb2VDYX5uhxbGQyzZjKT7C
186
187
  tico/serialize/operators/op_view.py,sha256=xxE-GvTJ1UpcHst5KXYz3qKY-eJQvXKKrSZiA2O7E40,2593
187
188
  tico/serialize/operators/op_where.py,sha256=doE81GSwygrPBm3JIfN9w7kKXxeIYKxgk0eoY22QIcg,2845
188
189
  tico/serialize/operators/utils.py,sha256=lXGpEJW1h8U_-gfc6EWjvvSiq3yJ9P-v1v3EMRT_pSk,2954
190
+ tico/serialize/operators/adapters/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
191
+ tico/serialize/operators/adapters/llama_rmsnorm.py,sha256=6t3dhfNpR03eIjsmhymF2JKd6lCf7PvInqMf77c_BOE,1139
189
192
  tico/utils/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
190
193
  tico/utils/convert.py,sha256=GgZwZtiqFzTdszfUQO0vcX39lKjs97gYwZ-Tiw_4Bbo,13222
191
194
  tico/utils/define.py,sha256=Ypgp7YffM4pgPl4Zh6TmogSn1OxGBMRw_e09qYGflZk,1467
@@ -200,20 +203,20 @@ tico/utils/padding.py,sha256=qKke-dJeeLHiRaePjDS66txrGyiYuipLVQeqLYad8uk,3349
200
203
  tico/utils/passes.py,sha256=kGmDe__5cPaO6i5EDAoXSVe6yXEoX9hAny4ROb3ZEmQ,2409
201
204
  tico/utils/pytree_utils.py,sha256=jrk3N6X6LiUnBCX_gM1K9nywbVAJBVnszlTAgeIeDUc,5219
202
205
  tico/utils/record_input.py,sha256=QN-8D71G_WAX3QQQ5CIwbEfFJZTQ3CvL4wCMiVddua4,3894
203
- tico/utils/register_custom_op.py,sha256=3-Yl6iYmx1qQA2igNHt4hYhQhQMkdPb7gF50LIY8yvc,27350
206
+ tico/utils/register_custom_op.py,sha256=n91UtmPedoqhkR8fBNRbk9Msq79pn9DHNHlt99l2s_w,28142
204
207
  tico/utils/serialize.py,sha256=mEuusEzi82WFsz3AkowgWwxSLeo50JDxyOj6yYDQhEI,1914
205
208
  tico/utils/signature.py,sha256=R2GV0alRpXEbZISqPKyxCUWbgDcsrQ2ovbVG3737IzA,9595
206
209
  tico/utils/torch_compat.py,sha256=oc6PztVsXdHcQ3iaVR90wLLxrGaj6zFHWZ8K9rRS6q8,1795
207
210
  tico/utils/trace_decorators.py,sha256=ddLIiKQfSaQrxgF1kNpwjFTQnXENzeSfcr1kuAW4jGI,3221
208
211
  tico/utils/utils.py,sha256=A5p3iAAxRGDsZJh4ybp-Qo3MX3vk5RrmSY-R3rXqVeI,12976
209
- tico/utils/validate_args_kwargs.py,sha256=CRj_SXMUUn6onsl8XLAt-zPZCFxR4C0XOCoaad_ZD4I,26689
212
+ tico/utils/validate_args_kwargs.py,sha256=yikeUbYfSg2378wagEMXDlJeSRv8HKI2oxpjWarolec,27268
210
213
  tico/utils/mx/__init__.py,sha256=IO6FP_xYbGy0dW0HL26GXD3ouxARaxCK7bz9dn4blPQ,26
211
214
  tico/utils/mx/elemwise_ops.py,sha256=V6glyAHsVR1joqpsgnNytatCD_ew92xNWZ19UFDoMTA,10281
212
215
  tico/utils/mx/formats.py,sha256=uzNWyu-1onUlwQfX5cZ6fZSUfHMRqorper7_T1k3jfk,3404
213
216
  tico/utils/mx/mx_ops.py,sha256=RcfUTYVi-wilGB2sC35OeARdwDqnixv7dG5iyZ-fQT8,8555
214
- tico-0.1.0.dev250810.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
215
- tico-0.1.0.dev250810.dist-info/METADATA,sha256=N1kcg1vk8kn6bLJgRUSZMZalY-mTn46jqLvVj9NGvR4,8450
216
- tico-0.1.0.dev250810.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
217
- tico-0.1.0.dev250810.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
218
- tico-0.1.0.dev250810.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
219
- tico-0.1.0.dev250810.dist-info/RECORD,,
217
+ tico-0.1.0.dev250811.dist-info/LICENSE,sha256=kp4JLII7bzRhPb0CPD5XTDZMh22BQ7h3k3B7t8TiSbw,12644
218
+ tico-0.1.0.dev250811.dist-info/METADATA,sha256=p8ENvpaXfA4jwrayPNaJqjCRRBKdoGu_1Vm08SBdqpU,8450
219
+ tico-0.1.0.dev250811.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
220
+ tico-0.1.0.dev250811.dist-info/entry_points.txt,sha256=kBKYSS_IYrSXmUYevmmepqIVPScq5vF8ulQRu3I_Zf0,59
221
+ tico-0.1.0.dev250811.dist-info/top_level.txt,sha256=oqs7UPoNSKZEwqsX8B-KAWdQwfAa7i60pbxW_Jk7P3w,5
222
+ tico-0.1.0.dev250811.dist-info/RECORD,,