nvidia-cutlass-operators 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. cutlass/operators/__init__.py +136 -0
  2. cutlass/operators/arch.py +219 -0
  3. cutlass/operators/arguments/__init__.py +60 -0
  4. cutlass/operators/arguments/base.py +134 -0
  5. cutlass/operators/arguments/epilogue.py +196 -0
  6. cutlass/operators/arguments/gemm.py +212 -0
  7. cutlass/operators/arguments/grouped_gemm.py +107 -0
  8. cutlass/operators/arguments/operand.py +396 -0
  9. cutlass/operators/artifact.py +55 -0
  10. cutlass/operators/base.py +318 -0
  11. cutlass/operators/config.py +210 -0
  12. cutlass/operators/fusion/__init__.py +126 -0
  13. cutlass/operators/fusion/activation.py +233 -0
  14. cutlass/operators/fusion/backend/__init__.py +45 -0
  15. cutlass/operators/fusion/backend/emitter_base.py +165 -0
  16. cutlass/operators/fusion/backend/sm100_emitter.py +151 -0
  17. cutlass/operators/fusion/backend/sm100_nodes.py +140 -0
  18. cutlass/operators/fusion/backend/sm80_emitter.py +46 -0
  19. cutlass/operators/fusion/backend/sm80_nodes.py +247 -0
  20. cutlass/operators/fusion/backend/sm90_emitter.py +97 -0
  21. cutlass/operators/fusion/backend/sm90_nodes.py +327 -0
  22. cutlass/operators/fusion/epilogue.py +88 -0
  23. cutlass/operators/fusion/evt_ops.py +116 -0
  24. cutlass/operators/fusion/frontend/__init__.py +41 -0
  25. cutlass/operators/fusion/frontend/frontend_base.py +303 -0
  26. cutlass/operators/fusion/frontend/python_ast.py +261 -0
  27. cutlass/operators/fusion/ir/__init__.py +75 -0
  28. cutlass/operators/fusion/ir/c_types.py +247 -0
  29. cutlass/operators/fusion/ir/compute_nodes.py +99 -0
  30. cutlass/operators/fusion/ir/dag_ir.py +258 -0
  31. cutlass/operators/fusion/ir/layout_algorithm.py +362 -0
  32. cutlass/operators/fusion/ir/layout_nodes.py +351 -0
  33. cutlass/operators/fusion/ir/load_nodes.py +312 -0
  34. cutlass/operators/fusion/ir/node.py +330 -0
  35. cutlass/operators/fusion/ir/store_nodes.py +276 -0
  36. cutlass/operators/fusion/ir/tensor.py +155 -0
  37. cutlass/operators/fusion/library.py +430 -0
  38. cutlass/operators/fusion/passes/__init__.py +59 -0
  39. cutlass/operators/fusion/passes/graph_drawer.py +133 -0
  40. cutlass/operators/fusion/passes/pass_argument_type.py +136 -0
  41. cutlass/operators/fusion/passes/pass_dag_2_tree.py +176 -0
  42. cutlass/operators/fusion/passes/pass_fix_element_d.py +86 -0
  43. cutlass/operators/fusion/passes/pass_get_impl.py +93 -0
  44. cutlass/operators/fusion/passes/pass_layout_elimination.py +230 -0
  45. cutlass/operators/fusion/passes/pass_manager.py +185 -0
  46. cutlass/operators/fusion/passes/pass_no_op_elimination.py +59 -0
  47. cutlass/operators/fusion/passes/pass_preprocess_red.py +96 -0
  48. cutlass/operators/fusion/passes/pass_shape_type_propagation.py +60 -0
  49. cutlass/operators/fusion/passes/smem_size_calculator.py +363 -0
  50. cutlass/operators/fusion/passes/util.py +46 -0
  51. cutlass/operators/fusion/pycute/__init__.py +36 -0
  52. cutlass/operators/fusion/pycute/int_tuple.py +229 -0
  53. cutlass/operators/fusion/pycute/layout.py +409 -0
  54. cutlass/operators/fusion/pycute/swizzle.py +133 -0
  55. cutlass/operators/fusion/pycute/typing.py +42 -0
  56. cutlass/operators/manifest.py +153 -0
  57. cutlass/operators/metadata/__init__.py +73 -0
  58. cutlass/operators/metadata/base.py +170 -0
  59. cutlass/operators/metadata/design/__init__.py +45 -0
  60. cutlass/operators/metadata/design/base.py +109 -0
  61. cutlass/operators/metadata/design/sm100.py +111 -0
  62. cutlass/operators/metadata/design/tile_scheduler.py +91 -0
  63. cutlass/operators/metadata/epilogue.py +71 -0
  64. cutlass/operators/metadata/operand_constraints.py +255 -0
  65. cutlass/operators/metadata/operands/__init__.py +37 -0
  66. cutlass/operators/metadata/operands/base.py +70 -0
  67. cutlass/operators/metadata/operands/gemm.py +90 -0
  68. cutlass/operators/metadata/operands/grouped_gemm.py +100 -0
  69. cutlass/operators/mma.py +166 -0
  70. cutlass/operators/providers/__init__.py +99 -0
  71. cutlass/operators/providers/cutedsl/__init__.py +66 -0
  72. cutlass/operators/providers/cutedsl/evt/common_efc.py +2189 -0
  73. cutlass/operators/providers/cutedsl/evt/converter.py +470 -0
  74. cutlass/operators/providers/cutedsl/gemm/__init__.py +43 -0
  75. cutlass/operators/providers/cutedsl/gemm/implementations/operator_helpers.py +73 -0
  76. cutlass/operators/providers/cutedsl/gemm/implementations/scheduler.py +589 -0
  77. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_contiguous_offset_2d3d_dense_gemm_impl.py +1976 -0
  78. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_dense_blockscaled_static_persistent_impl.py +2170 -0
  79. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_mixed_input_impl.py +2058 -0
  80. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_persistent_impl.py +1277 -0
  81. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_persistent_preferred_cluster_impl.py +1176 -0
  82. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_static_persistent_efc_impl.py +1686 -0
  83. cutlass/operators/providers/cutedsl/gemm/implementations/sm100_tgv_gemm_impl.py +1085 -0
  84. cutlass/operators/providers/cutedsl/gemm/implementations/sm80_tensorop_gemm_impl.py +814 -0
  85. cutlass/operators/providers/cutedsl/gemm/implementations/sm90_static_persistent_impl.py +1236 -0
  86. cutlass/operators/providers/cutedsl/gemm/sm100_contiguous_offset_2d3d_dense_gemm.py +451 -0
  87. cutlass/operators/providers/cutedsl/gemm/sm100_dense_blockscaled_static_persistent.py +606 -0
  88. cutlass/operators/providers/cutedsl/gemm/sm100_mixed_input.py +910 -0
  89. cutlass/operators/providers/cutedsl/gemm/sm100_persistent.py +494 -0
  90. cutlass/operators/providers/cutedsl/gemm/sm100_persistent_preferred_cluster.py +671 -0
  91. cutlass/operators/providers/cutedsl/gemm/sm100_static_persistent_efc.py +506 -0
  92. cutlass/operators/providers/cutedsl/gemm/sm100_tgv_gemm.py +420 -0
  93. cutlass/operators/providers/cutedsl/gemm/sm80_tensorop_gemm.py +316 -0
  94. cutlass/operators/providers/cutedsl/gemm/sm90_static_persistent.py +441 -0
  95. cutlass/operators/providers/cutedsl/integration_utils/__init__.py +40 -0
  96. cutlass/operators/providers/cutedsl/integration_utils/builders.py +97 -0
  97. cutlass/operators/providers/cutedsl/integration_utils/mma.py +49 -0
  98. cutlass/operators/providers/cutedsl/operator.py +90 -0
  99. cutlass/operators/providers/provider.py +153 -0
  100. cutlass/operators/status.py +68 -0
  101. cutlass/operators/typing.py +73 -0
  102. cutlass/operators/utils/__init__.py +39 -0
  103. cutlass/operators/utils/common.py +84 -0
  104. cutlass/operators/utils/device.py +247 -0
  105. cutlass/operators/utils/dtype.py +213 -0
  106. cutlass/operators/utils/gemm.py +96 -0
  107. cutlass/operators/utils/generate.py +100 -0
  108. cutlass/operators/utils/layout.py +80 -0
  109. cutlass/operators/utils/tensor.py +575 -0
  110. cutlass/operators/workspace.py +67 -0
  111. nvidia_cutlass_operators-0.1.0.dist-info/METADATA +198 -0
  112. nvidia_cutlass_operators-0.1.0.dist-info/RECORD +115 -0
  113. nvidia_cutlass_operators-0.1.0.dist-info/WHEEL +5 -0
  114. nvidia_cutlass_operators-0.1.0.dist-info/licenses/LICENSE.txt +27 -0
  115. nvidia_cutlass_operators-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+
30
+ from __future__ import annotations
31
+
32
+ from typing import TYPE_CHECKING
33
+
34
+ from cutlass.operators import mma, workspace
35
+ from cutlass.operators.arch import ArchPortability, TargetSm
36
+ from cutlass.operators.arguments import (
37
+ DenseTensor,
38
+ EpilogueArguments,
39
+ GemmArguments,
40
+ GroupedGemmArguments,
41
+ Operand,
42
+ PerformanceControls,
43
+ RuntimeArguments,
44
+ ScaledOperand,
45
+ ScaleMode,
46
+ ScaleSwizzleMode,
47
+ )
48
+ from cutlass.operators.artifact import CompiledArtifact
49
+ from cutlass.operators.base import Operator
50
+ from cutlass.operators.config import GlobalOptions
51
+ from cutlass.operators.manifest import Manifest
52
+ from cutlass.operators.metadata import OperatorMetadata
53
+ from cutlass.operators.providers import (
54
+ CuTeDSLProvider,
55
+ Provider,
56
+ available_providers,
57
+ register_provider,
58
+ )
59
+ from cutlass.operators.status import Status
60
+ from cutlass.operators.typing import NumericLike, TensorLike
61
+
62
+ if TYPE_CHECKING:
63
+ from collections.abc import Callable
64
+
65
+
66
__version__ = "0.1.0"


def get_operators(
    args: RuntimeArguments | None = None,
    metadata_filter: Callable[[OperatorMetadata], bool] | None = None,
    target_sm: TargetSm | str | None = None,
    providers: list[Provider] | None = None,
) -> list[Operator]:
    """Discover the Operators compatible with the given constraints.

    All filtering is delegated to :meth:`Manifest.get_operators`; this function
    is the package-level convenience entry point for it.

    Args:
        args (RuntimeArguments | None): Runtime arguments for the intended
            invocation (for example a :class:`GemmArguments`). ``None`` skips
            argument-based filtering.
        metadata_filter (Callable[[OperatorMetadata], bool] | None): Optional
            predicate that receives each candidate's :class:`OperatorMetadata`
            and returns whether it should be kept. Its result is intersected
            with the filtering done by the other parameters.
        target_sm (TargetSm | str | None): Compute capability to target, either
            a :class:`TargetSm` or a string such as ``"100a"``. Operators that
            cannot run on this target are excluded.
        providers (list[Provider] | None): Optional list of Providers to
            restrict discovery to (e.g. ``[ops.CuTeDSLProvider]``).

    Returns:
        list[Operator]: The Operators that satisfy every filter.
    """
    matching = Manifest.get_operators(args, metadata_filter, target_sm, providers)
    return matching
96
+
97
+
98
__all__ = [
    # The central class exposing the Operator interface
    "Operator",
    # Runtime arguments handed to Operators
    "RuntimeArguments", "PerformanceControls", "EpilogueArguments",
    "GemmArguments", "GroupedGemmArguments",
    # Operand types
    "Operand", "DenseTensor", "ScaledOperand", "ScaleMode", "ScaleSwizzleMode",
    # Discovering Operators
    "get_operators", "Manifest", "OperatorMetadata",
    # Miscellaneous core types
    "Status", "CompiledArtifact",
    # Architecture targeting
    "TargetSm", "ArchPortability",
    # Typing markers
    "TensorLike", "NumericLike",
    # Global configuration
    "GlobalOptions",
    # Provider management: the Provider base class, the CuTe DSL kernel
    # provider, the registered Providers, and the registration decorator
    "Provider", "CuTeDSLProvider", "available_providers", "register_provider",
    # Submodules that expose less commonly used public surfaces
    "mma", "workspace",
]
@@ -0,0 +1,219 @@
1
+ # Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ from __future__ import annotations
30
+
31
+ import enum
32
+ from dataclasses import dataclass
33
+ from enum import auto as enum_auto
34
+ from typing import TYPE_CHECKING
35
+
36
+ if TYPE_CHECKING:
37
+ from collections.abc import Iterable
38
+
39
+ from cutlass.operators.metadata import DesignMetadata, OperandsMetadata
40
+
41
+
42
+ @dataclass(init=False)
43
+ class TargetSm:
44
+ """Target compute capability & portability to compile for."""
45
+
46
+ cc: int
47
+ portability: ArchPortability
48
+
49
+ def __init__(
50
+ self,
51
+ sm_str: str | None = None,
52
+ *,
53
+ cc: int | None = None,
54
+ portability: ArchPortability | None = None,
55
+ ):
56
+ """Create a ``TargetSm``.
57
+
58
+ Args:
59
+ sm_str (str | None): String like "100", "100a", or "100f".
60
+ cc (int | None): Compute capability as integer (e.g., 100).
61
+ portability (ArchPortability | None): Architecture portability. If not
62
+ provided, defaults to ArchConditional if cc >= 90, else Portable.
63
+
64
+ Note:
65
+ Must use either sm_str or cc/portability, not both.
66
+
67
+ Example:
68
+ >>> TargetSm("100a")
69
+ >>> TargetSm(cc=100)
70
+ >>> TargetSm(cc=100, portability=ArchPortability.ArchConditional)
71
+
72
+ Raises:
73
+ ValueError: If both sm_str and cc/portability are provided, or neither.
74
+ ValueError: If portability is not valid for the given cc.
75
+ """
76
+ if sm_str is not None:
77
+ if cc is not None or portability is not None:
78
+ raise ValueError("Cannot specify sm_str with cc/portability")
79
+ self.cc, self.portability = self._parse_str(sm_str)
80
+ elif cc is not None:
81
+ self.cc = cc
82
+ self.portability = portability or (
83
+ ArchPortability.ArchConditional
84
+ if cc >= 90
85
+ else ArchPortability.Portable
86
+ )
87
+ else:
88
+ raise ValueError("Must specify either sm_str or cc/portability")
89
+
90
+ if self.portability == ArchPortability.ArchConditional and self.cc < 90:
91
+ raise ValueError(
92
+ f"TargetSm {self.__str__()} is invalid. ArchConditional targets must be Hopper (sm90) or newer"
93
+ )
94
+ if self.portability == ArchPortability.FamilyPortable and self.cc < 100:
95
+ raise ValueError(
96
+ f"TargetSm {self.__str__()} is invalid. FamilyPortable targets must be Blackwell (sm100) or newer"
97
+ )
98
+
99
+ @classmethod
100
+ def ensure(cls, value: TargetSm | str | None) -> TargetSm | None:
101
+ """Coerce a string to ``TargetSm``, or return as-is if already one."""
102
+ if value is None or isinstance(value, cls):
103
+ return value
104
+ if isinstance(value, str):
105
+ return cls(value)
106
+ raise TypeError(f"Expected TargetSm or str. Got {value} of type {type(value)}")
107
+
108
+ @staticmethod
109
+ def get_supported_targets(
110
+ design: DesignMetadata, operands: OperandsMetadata
111
+ ) -> list[TargetSm]:
112
+ """Get the supported targets for a given design and operands.
113
+
114
+ Raises:
115
+ ValueError: If ``design`` does not have a ``mma_instruction_type`` field.
116
+ """
117
+ # For now, just return the supported instruction.
118
+ # It can be extended to any other rules that are introduced.
119
+ metadata_to_check = []
120
+
121
+ if not hasattr(design, "mma_instruction_type"):
122
+ raise ValueError(
123
+ "Cannot determine supported targets with missing mma_instruction_type field in "
124
+ f"design metadata: {design}"
125
+ )
126
+ mma_supported_targets = design.mma_instruction_type.supported_targets(
127
+ design, operands
128
+ )
129
+
130
+ metadata_to_check.append(mma_supported_targets)
131
+
132
+ # check against tile scheduler supported targets
133
+ if hasattr(design, "tile_scheduler") and design.tile_scheduler:
134
+ tile_scheduler_supported_targets = design.tile_scheduler.supported_targets(
135
+ design, operands
136
+ )
137
+ metadata_to_check.append(tile_scheduler_supported_targets)
138
+ return list(set.intersection(*map(set, metadata_to_check)))
139
+
140
+ def is_portable_to(self, other: TargetSm | str) -> bool:
141
+ """Check if this target can compile/run on architecture described by `other` TargetSm."""
142
+ other = TargetSm.ensure(other)
143
+ match self.portability:
144
+ case ArchPortability.Portable:
145
+ return self.cc <= other.cc
146
+ case ArchPortability.FamilyPortable:
147
+ return self.major == other.major and self.cc <= other.cc
148
+ case ArchPortability.ArchConditional:
149
+ return self == other
150
+ case _:
151
+ raise NotImplementedError
152
+
153
+ def supports_operators_from(
154
+ self, operator_targets: Iterable[TargetSm | str]
155
+ ) -> bool:
156
+ """Check if this target can compile/run operators designed for any of the given targets."""
157
+ return any(
158
+ TargetSm.ensure(target).is_portable_to(self) for target in operator_targets
159
+ )
160
+
161
+ @property
162
+ def major(self) -> int:
163
+ """Get the major version of the target SM."""
164
+ return self.cc // 10
165
+
166
+ @property
167
+ def minor(self) -> int:
168
+ """Get the minor version of the target SM."""
169
+ return self.cc % 10
170
+
171
+ @staticmethod
172
+ def _parse_str(s: str) -> tuple[int, ArchPortability]:
173
+ """Parse a compile target string (e.g., "100", "100a", "100f")."""
174
+ portability_map = {
175
+ "a": ArchPortability.ArchConditional,
176
+ "f": ArchPortability.FamilyPortable,
177
+ }
178
+
179
+ suffix = s[-1]
180
+ try:
181
+ if suffix.isdigit():
182
+ return (int(s), ArchPortability.Portable)
183
+ return (int(s[:-1]), portability_map[suffix])
184
+ except (ValueError, KeyError):
185
+ raise ValueError(
186
+ f"Invalid TargetSm: '{s}'. Expected form: '100', '100a', '100f'"
187
+ ) from None
188
+
189
+ def __str__(self) -> str:
190
+ suffix = (
191
+ self.portability.name[0].lower()
192
+ if self.portability != ArchPortability.Portable
193
+ else ""
194
+ )
195
+ return f"{self.cc}{suffix}"
196
+
197
+ def __hash__(self) -> int:
198
+ return hash((self.cc, self.portability))
199
+
200
+
201
+ class ArchPortability(enum.Enum):
202
+ """Portability of a compiled Operator to other architectures/compute capabilities.
203
+
204
+ See:
205
+ https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/compute-capabilities.html#feature-set-compiler-targets
206
+ """
207
+
208
+ Portable = enum_auto()
209
+ """Portable to future architectures (e.g. sm_100 without "a" or "f")."""
210
+
211
+ FamilyPortable = enum_auto()
212
+ """Portable to future architectures within the same family (e.g. sm_100f).
213
+ Only applicable to Blackwell and newer architectures (cc >= 100).
214
+ """
215
+
216
+ ArchConditional = enum_auto()
217
+ """Not portable to any other architecture (e.g. sm_100a).
218
+ Only applicable to Hopper and newer architectures (cc >= 90).
219
+ """
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+
30
+ from .base import (
31
+ Operand,
32
+ PerformanceControls,
33
+ RuntimeArguments,
34
+ )
35
+ from .epilogue import EpilogueArguments
36
+ from .gemm import GemmArguments, GemmProblemSize
37
+ from .grouped_gemm import GroupedGemmArguments
38
+ from .operand import (
39
+ DenseTensor,
40
+ ScaledOperand,
41
+ ScaleMode,
42
+ ScaleSwizzleMode,
43
+ )
44
+
45
__all__ = [
    # The RuntimeArguments base class and its shared constituents
    "RuntimeArguments", "PerformanceControls", "EpilogueArguments",
    # Operation-specific arguments and related types
    "GemmArguments", "GemmProblemSize", "GroupedGemmArguments",
    # Operands and related classes
    "Operand", "DenseTensor", "ScaledOperand", "ScaleMode", "ScaleSwizzleMode",
]
@@ -0,0 +1,134 @@
1
+ # Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-3-Clause
3
+
4
+ # Redistribution and use in source and binary forms, with or without
5
+ # modification, are permitted provided that the following conditions are met:
6
+
7
+ # 1. Redistributions of source code must retain the above copyright notice, this
8
+ # list of conditions and the following disclaimer.
9
+
10
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
11
+ # this list of conditions and the following disclaimer in the documentation
12
+ # and/or other materials provided with the distribution.
13
+
14
+ # 3. Neither the name of the copyright holder nor the names of its
15
+ # contributors may be used to endorse or promote products derived from
16
+ # this software without specific prior written permission.
17
+
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+ from __future__ import annotations
30
+
31
+ import copy
32
+ from abc import ABC, abstractmethod
33
+ from dataclasses import dataclass, field, fields
34
+ from typing import Any, final, get_type_hints
35
+
36
+ from cutlass.operators.typing import NumericLike, TensorLike
37
+ from cutlass.operators.utils.dtype import to_cutlass_type
38
+ from cutlass.operators.utils.tensor import TensorWrapper
39
+
40
+
41
+ @dataclass
42
+ class RuntimeArguments:
43
+ """Describes the operands and all other arguments passed to an Operator at runtime.
44
+
45
+ It contains runtime operands (usually tensors) passed to the operation, as well as
46
+ any custom epilogue fusions, and runtime performance controls.
47
+
48
+ This is an abstract base class, whose subclass describes the operation type itself
49
+ (e.g. GemmArguments, GroupedGemmArguments, etc.). All operators that implement the same
50
+ operation type accept the same RuntimeArguments subclass.
51
+ """
52
+
53
+ performance: PerformanceControls | None = field(default=None, kw_only=True)
54
+ """Optional runtime performance controls passed to the Operator"""
55
+
56
+ def _validate(self):
57
+ """Checks that the arguments are valid.
58
+
59
+ This is run before all fields have been converted to TensorWrapper and cutlass.Numeric.
60
+ """
61
+
62
+ def __post_init__(self):
63
+ _convert_to_internal_types(self)
64
+
65
+
66
@dataclass
class PerformanceControls:
    """Container for runtime performance controls passed to an Operator.

    Some Operators expose performance options that can be adjusted at runtime;
    this class is the general home for all such controls. The base class
    itself declares no fields.
    """
73
+
74
+
75
class Operand(ABC):
    """Base class for all operands to Operators, encapsulating one or more TensorLike objects.

    In the most basic case, an Operand encapsulates a single tensor.

    In more complex cases, an Operand may encapsulate multiple tensors that together
    form a single logical operand. For instance, a :class:`~cutlass.operators.ScaledOperand`
    encapsulates a quantized tensor and a scale tensor that together reconstruct the
    logical value of the operand.
    """

    @final
    def copy(self) -> Operand:
        """Return a copy of the operand. Does not copy the underlying tensor.

        Delegates to :meth:`__copy__`; subclasses customize copying there, since
        ``copy`` itself is marked ``@final``.
        """
        return self.__copy__()

    @abstractmethod
    def __copy__(self) -> Operand:
        """Return a copy of the operand. Does not copy the underlying tensor.

        Subclasses must implement this. It is also what :func:`copy.copy` invokes.
        """
        raise NotImplementedError

    def _convert_to_internal_types(self, metadata: dict[str, Any] | None = None):
        """Convert this operand's fields to internal types, in place.

        Delegates to the module-level ``_convert_to_internal_types`` helper.

        Args:
            metadata (dict[str, Any] | None): Optional extra metadata forwarded to
                the conversion.
        """
        _convert_to_internal_types(self, metadata=metadata)
97
+
98
+
99
def _convert_to_internal_types(caller, metadata: dict[str, Any] | None = None):
    """Convert the dataclass fields of ``caller`` to internal types, in place.

    Per field:
      * values that are already ``TensorWrapper`` are left unchanged;
      * fields whose resolved type hint is exactly ``TensorLike`` are wrapped in
        ``TensorWrapper`` (constructed with the merged metadata as keywords);
      * fields whose resolved type hint is exactly ``NumericLike`` are converted
        with ``to_cutlass_type`` (i.e. to ``cutlass.Numeric``);
      * values that implement ``_convert_to_internal_types`` (e.g. an ``Operand``)
        convert themselves, receiving the merged metadata.

    Hints are compared by identity, so a field annotated with a union such as
    ``TensorLike | None`` does not match the ``TensorLike`` branch.
    NOTE(review): confirm that optional tensor fields are not expected to be wrapped here.

    Args:
        caller (Any): Dataclass instance whose fields are converted.
        metadata (dict[str, Any] | None): Additional metadata used for conversion.
            It is deep-copied for every field and then updated with that field's
            own ``dataclasses.field(metadata=...)``, so field-level entries win and
            the caller's dict is never mutated.
    """
    type_hints = get_type_hints(type(caller))
    for f in fields(caller):
        hint = type_hints.get(f.name)
        value = getattr(caller, f.name)

        # Fresh copy per field so that the update below cannot leak one field's
        # metadata into the next field or into the caller's dict.
        field_metadata = {} if metadata is None else copy.deepcopy(metadata)
        field_metadata.update(f.metadata)

        if isinstance(value, TensorWrapper):
            # Already in internal form; reassigned unchanged.
            setattr(caller, f.name, value)
        elif hint is TensorLike:
            setattr(caller, f.name, TensorWrapper(value, **field_metadata))
        elif hint is NumericLike:
            setattr(caller, f.name, to_cutlass_type(value))
        elif hasattr(value, "_convert_to_internal_types"):
            # Nested object converts its own fields in place.
            value._convert_to_internal_types(metadata=field_metadata)
            setattr(caller, f.name, value)