nvidia-cutlass-operators 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cutlass/operators/__init__.py +136 -0
- cutlass/operators/arch.py +219 -0
- cutlass/operators/arguments/__init__.py +60 -0
- cutlass/operators/arguments/base.py +134 -0
- cutlass/operators/arguments/epilogue.py +196 -0
- cutlass/operators/arguments/gemm.py +212 -0
- cutlass/operators/arguments/grouped_gemm.py +107 -0
- cutlass/operators/arguments/operand.py +396 -0
- cutlass/operators/artifact.py +55 -0
- cutlass/operators/base.py +318 -0
- cutlass/operators/config.py +210 -0
- cutlass/operators/fusion/__init__.py +126 -0
- cutlass/operators/fusion/activation.py +233 -0
- cutlass/operators/fusion/backend/__init__.py +45 -0
- cutlass/operators/fusion/backend/emitter_base.py +165 -0
- cutlass/operators/fusion/backend/sm100_emitter.py +151 -0
- cutlass/operators/fusion/backend/sm100_nodes.py +140 -0
- cutlass/operators/fusion/backend/sm80_emitter.py +46 -0
- cutlass/operators/fusion/backend/sm80_nodes.py +247 -0
- cutlass/operators/fusion/backend/sm90_emitter.py +97 -0
- cutlass/operators/fusion/backend/sm90_nodes.py +327 -0
- cutlass/operators/fusion/epilogue.py +88 -0
- cutlass/operators/fusion/evt_ops.py +116 -0
- cutlass/operators/fusion/frontend/__init__.py +41 -0
- cutlass/operators/fusion/frontend/frontend_base.py +303 -0
- cutlass/operators/fusion/frontend/python_ast.py +261 -0
- cutlass/operators/fusion/ir/__init__.py +75 -0
- cutlass/operators/fusion/ir/c_types.py +247 -0
- cutlass/operators/fusion/ir/compute_nodes.py +99 -0
- cutlass/operators/fusion/ir/dag_ir.py +258 -0
- cutlass/operators/fusion/ir/layout_algorithm.py +362 -0
- cutlass/operators/fusion/ir/layout_nodes.py +351 -0
- cutlass/operators/fusion/ir/load_nodes.py +312 -0
- cutlass/operators/fusion/ir/node.py +330 -0
- cutlass/operators/fusion/ir/store_nodes.py +276 -0
- cutlass/operators/fusion/ir/tensor.py +155 -0
- cutlass/operators/fusion/library.py +430 -0
- cutlass/operators/fusion/passes/__init__.py +59 -0
- cutlass/operators/fusion/passes/graph_drawer.py +133 -0
- cutlass/operators/fusion/passes/pass_argument_type.py +136 -0
- cutlass/operators/fusion/passes/pass_dag_2_tree.py +176 -0
- cutlass/operators/fusion/passes/pass_fix_element_d.py +86 -0
- cutlass/operators/fusion/passes/pass_get_impl.py +93 -0
- cutlass/operators/fusion/passes/pass_layout_elimination.py +230 -0
- cutlass/operators/fusion/passes/pass_manager.py +185 -0
- cutlass/operators/fusion/passes/pass_no_op_elimination.py +59 -0
- cutlass/operators/fusion/passes/pass_preprocess_red.py +96 -0
- cutlass/operators/fusion/passes/pass_shape_type_propagation.py +60 -0
- cutlass/operators/fusion/passes/smem_size_calculator.py +363 -0
- cutlass/operators/fusion/passes/util.py +46 -0
- cutlass/operators/fusion/pycute/__init__.py +36 -0
- cutlass/operators/fusion/pycute/int_tuple.py +229 -0
- cutlass/operators/fusion/pycute/layout.py +409 -0
- cutlass/operators/fusion/pycute/swizzle.py +133 -0
- cutlass/operators/fusion/pycute/typing.py +42 -0
- cutlass/operators/manifest.py +153 -0
- cutlass/operators/metadata/__init__.py +73 -0
- cutlass/operators/metadata/base.py +170 -0
- cutlass/operators/metadata/design/__init__.py +45 -0
- cutlass/operators/metadata/design/base.py +109 -0
- cutlass/operators/metadata/design/sm100.py +111 -0
- cutlass/operators/metadata/design/tile_scheduler.py +91 -0
- cutlass/operators/metadata/epilogue.py +71 -0
- cutlass/operators/metadata/operand_constraints.py +255 -0
- cutlass/operators/metadata/operands/__init__.py +37 -0
- cutlass/operators/metadata/operands/base.py +70 -0
- cutlass/operators/metadata/operands/gemm.py +90 -0
- cutlass/operators/metadata/operands/grouped_gemm.py +100 -0
- cutlass/operators/mma.py +166 -0
- cutlass/operators/providers/__init__.py +99 -0
- cutlass/operators/providers/cutedsl/__init__.py +66 -0
- cutlass/operators/providers/cutedsl/evt/common_efc.py +2189 -0
- cutlass/operators/providers/cutedsl/evt/converter.py +470 -0
- cutlass/operators/providers/cutedsl/gemm/__init__.py +43 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/operator_helpers.py +73 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/scheduler.py +589 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_contiguous_offset_2d3d_dense_gemm_impl.py +1976 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_dense_blockscaled_static_persistent_impl.py +2170 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_mixed_input_impl.py +2058 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_persistent_impl.py +1277 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_persistent_preferred_cluster_impl.py +1176 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_static_persistent_efc_impl.py +1686 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm100_tgv_gemm_impl.py +1085 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm80_tensorop_gemm_impl.py +814 -0
- cutlass/operators/providers/cutedsl/gemm/implementations/sm90_static_persistent_impl.py +1236 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_contiguous_offset_2d3d_dense_gemm.py +451 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_dense_blockscaled_static_persistent.py +606 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_mixed_input.py +910 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_persistent.py +494 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_persistent_preferred_cluster.py +671 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_static_persistent_efc.py +506 -0
- cutlass/operators/providers/cutedsl/gemm/sm100_tgv_gemm.py +420 -0
- cutlass/operators/providers/cutedsl/gemm/sm80_tensorop_gemm.py +316 -0
- cutlass/operators/providers/cutedsl/gemm/sm90_static_persistent.py +441 -0
- cutlass/operators/providers/cutedsl/integration_utils/__init__.py +40 -0
- cutlass/operators/providers/cutedsl/integration_utils/builders.py +97 -0
- cutlass/operators/providers/cutedsl/integration_utils/mma.py +49 -0
- cutlass/operators/providers/cutedsl/operator.py +90 -0
- cutlass/operators/providers/provider.py +153 -0
- cutlass/operators/status.py +68 -0
- cutlass/operators/typing.py +73 -0
- cutlass/operators/utils/__init__.py +39 -0
- cutlass/operators/utils/common.py +84 -0
- cutlass/operators/utils/device.py +247 -0
- cutlass/operators/utils/dtype.py +213 -0
- cutlass/operators/utils/gemm.py +96 -0
- cutlass/operators/utils/generate.py +100 -0
- cutlass/operators/utils/layout.py +80 -0
- cutlass/operators/utils/tensor.py +575 -0
- cutlass/operators/workspace.py +67 -0
- nvidia_cutlass_operators-0.1.0.dist-info/METADATA +198 -0
- nvidia_cutlass_operators-0.1.0.dist-info/RECORD +115 -0
- nvidia_cutlass_operators-0.1.0.dist-info/WHEEL +5 -0
- nvidia_cutlass_operators-0.1.0.dist-info/licenses/LICENSE.txt +27 -0
- nvidia_cutlass_operators-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
from __future__ import annotations
|
|
31
|
+
|
|
32
|
+
from typing import TYPE_CHECKING
|
|
33
|
+
|
|
34
|
+
from cutlass.operators import mma, workspace
|
|
35
|
+
from cutlass.operators.arch import ArchPortability, TargetSm
|
|
36
|
+
from cutlass.operators.arguments import (
|
|
37
|
+
DenseTensor,
|
|
38
|
+
EpilogueArguments,
|
|
39
|
+
GemmArguments,
|
|
40
|
+
GroupedGemmArguments,
|
|
41
|
+
Operand,
|
|
42
|
+
PerformanceControls,
|
|
43
|
+
RuntimeArguments,
|
|
44
|
+
ScaledOperand,
|
|
45
|
+
ScaleMode,
|
|
46
|
+
ScaleSwizzleMode,
|
|
47
|
+
)
|
|
48
|
+
from cutlass.operators.artifact import CompiledArtifact
|
|
49
|
+
from cutlass.operators.base import Operator
|
|
50
|
+
from cutlass.operators.config import GlobalOptions
|
|
51
|
+
from cutlass.operators.manifest import Manifest
|
|
52
|
+
from cutlass.operators.metadata import OperatorMetadata
|
|
53
|
+
from cutlass.operators.providers import (
|
|
54
|
+
CuTeDSLProvider,
|
|
55
|
+
Provider,
|
|
56
|
+
available_providers,
|
|
57
|
+
register_provider,
|
|
58
|
+
)
|
|
59
|
+
from cutlass.operators.status import Status
|
|
60
|
+
from cutlass.operators.typing import NumericLike, TensorLike
|
|
61
|
+
|
|
62
|
+
if TYPE_CHECKING:
|
|
63
|
+
from collections.abc import Callable
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
# Version string of the ``cutlass.operators`` package.
# NOTE(review): presumably must be kept in sync with the distribution metadata
# version (0.1.0 wheel); confirm when cutting a release.
__version__ = "0.1.0"
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def get_operators(
    args: RuntimeArguments | None = None,
    metadata_filter: Callable[[OperatorMetadata], bool] | None = None,
    target_sm: TargetSm | str | None = None,
    providers: list[Provider] | None = None,
) -> list[Operator]:
    """Discover the Operators compatible with the supplied constraints.

    Convenience front-end that forwards every argument, unchanged and in the
    same order, to :meth:`Manifest.get_operators`.

    Args:
        args (RuntimeArguments | None): Runtime arguments describing the
            intended invocation (e.g. :class:`GemmArguments`). ``None``
            disables argument-based filtering.
        metadata_filter (Callable[[OperatorMetadata], bool] | None): Optional
            predicate evaluated on an Operator's metadata, returning whether
            that Operator should be considered for inclusion. Its verdict is
            intersected with the other filters.
        target_sm (TargetSm | str | None): Compute capability to target, given
            either as a string such as ``"100a"`` or as a :class:`TargetSm`.
            Operators that cannot run on this target are filtered out.
        providers (list[Provider] | None): Optional list of Providers that
            discovery is restricted to (e.g. ``[ops.CuTeDSLProvider]``).

    Returns:
        list[Operator]: The Operators that satisfy every filter.
    """
    matching = Manifest.get_operators(args, metadata_filter, target_sm, providers)
    return matching
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Public API of ``cutlass.operators``: the names exported by
# ``from cutlass.operators import *``, grouped by purpose. Each entry must name
# an object imported or defined in this module.
__all__ = [
    # Central class exposing the Operator interface
    "Operator",
    # Runtime arguments
    "RuntimeArguments",
    "PerformanceControls",
    "EpilogueArguments",
    "GemmArguments",
    "GroupedGemmArguments",
    # Operands
    "Operand",
    "DenseTensor",
    "ScaledOperand",
    "ScaleMode",
    "ScaleSwizzleMode",
    # Operator Discovery
    "get_operators",
    "Manifest",
    "OperatorMetadata",
    # Misc. core types
    "Status",
    "CompiledArtifact",
    # Arch
    "TargetSm",
    "ArchPortability",
    # Typing markers
    "TensorLike",
    "NumericLike",
    # Configuration
    "GlobalOptions",
    # Provider management
    "Provider",  # base class for all Providers
    "CuTeDSLProvider",  # CuTeDSLProvider holding registry of kernels written in CuTe DSL
    "available_providers",  # global list of available, registered Providers
    "register_provider",  # decorator to register a new Provider
    # Submodules to access less commonly used public surfaces
    "mma",
    "workspace",
]
|
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import enum
|
|
32
|
+
from dataclasses import dataclass
|
|
33
|
+
from enum import auto as enum_auto
|
|
34
|
+
from typing import TYPE_CHECKING
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from collections.abc import Iterable
|
|
38
|
+
|
|
39
|
+
from cutlass.operators.metadata import DesignMetadata, OperandsMetadata
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(init=False)
|
|
43
|
+
class TargetSm:
|
|
44
|
+
"""Target compute capability & portability to compile for."""
|
|
45
|
+
|
|
46
|
+
cc: int
|
|
47
|
+
portability: ArchPortability
|
|
48
|
+
|
|
49
|
+
def __init__(
|
|
50
|
+
self,
|
|
51
|
+
sm_str: str | None = None,
|
|
52
|
+
*,
|
|
53
|
+
cc: int | None = None,
|
|
54
|
+
portability: ArchPortability | None = None,
|
|
55
|
+
):
|
|
56
|
+
"""Create a ``TargetSm``.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
sm_str (str | None): String like "100", "100a", or "100f".
|
|
60
|
+
cc (int | None): Compute capability as integer (e.g., 100).
|
|
61
|
+
portability (ArchPortability | None): Architecture portability. If not
|
|
62
|
+
provided, defaults to ArchConditional if cc >= 90, else Portable.
|
|
63
|
+
|
|
64
|
+
Note:
|
|
65
|
+
Must use either sm_str or cc/portability, not both.
|
|
66
|
+
|
|
67
|
+
Example:
|
|
68
|
+
>>> TargetSm("100a")
|
|
69
|
+
>>> TargetSm(cc=100)
|
|
70
|
+
>>> TargetSm(cc=100, portability=ArchPortability.ArchConditional)
|
|
71
|
+
|
|
72
|
+
Raises:
|
|
73
|
+
ValueError: If both sm_str and cc/portability are provided, or neither.
|
|
74
|
+
ValueError: If portability is not valid for the given cc.
|
|
75
|
+
"""
|
|
76
|
+
if sm_str is not None:
|
|
77
|
+
if cc is not None or portability is not None:
|
|
78
|
+
raise ValueError("Cannot specify sm_str with cc/portability")
|
|
79
|
+
self.cc, self.portability = self._parse_str(sm_str)
|
|
80
|
+
elif cc is not None:
|
|
81
|
+
self.cc = cc
|
|
82
|
+
self.portability = portability or (
|
|
83
|
+
ArchPortability.ArchConditional
|
|
84
|
+
if cc >= 90
|
|
85
|
+
else ArchPortability.Portable
|
|
86
|
+
)
|
|
87
|
+
else:
|
|
88
|
+
raise ValueError("Must specify either sm_str or cc/portability")
|
|
89
|
+
|
|
90
|
+
if self.portability == ArchPortability.ArchConditional and self.cc < 90:
|
|
91
|
+
raise ValueError(
|
|
92
|
+
f"TargetSm {self.__str__()} is invalid. ArchConditional targets must be Hopper (sm90) or newer"
|
|
93
|
+
)
|
|
94
|
+
if self.portability == ArchPortability.FamilyPortable and self.cc < 100:
|
|
95
|
+
raise ValueError(
|
|
96
|
+
f"TargetSm {self.__str__()} is invalid. FamilyPortable targets must be Blackwell (sm100) or newer"
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
@classmethod
|
|
100
|
+
def ensure(cls, value: TargetSm | str | None) -> TargetSm | None:
|
|
101
|
+
"""Coerce a string to ``TargetSm``, or return as-is if already one."""
|
|
102
|
+
if value is None or isinstance(value, cls):
|
|
103
|
+
return value
|
|
104
|
+
if isinstance(value, str):
|
|
105
|
+
return cls(value)
|
|
106
|
+
raise TypeError(f"Expected TargetSm or str. Got {value} of type {type(value)}")
|
|
107
|
+
|
|
108
|
+
@staticmethod
|
|
109
|
+
def get_supported_targets(
|
|
110
|
+
design: DesignMetadata, operands: OperandsMetadata
|
|
111
|
+
) -> list[TargetSm]:
|
|
112
|
+
"""Get the supported targets for a given design and operands.
|
|
113
|
+
|
|
114
|
+
Raises:
|
|
115
|
+
ValueError: If ``design`` does not have a ``mma_instruction_type`` field.
|
|
116
|
+
"""
|
|
117
|
+
# For now, just return the supported instruction.
|
|
118
|
+
# It can be extended to any other rules that are introduced.
|
|
119
|
+
metadata_to_check = []
|
|
120
|
+
|
|
121
|
+
if not hasattr(design, "mma_instruction_type"):
|
|
122
|
+
raise ValueError(
|
|
123
|
+
"Cannot determine supported targets with missing mma_instruction_type field in "
|
|
124
|
+
f"design metadata: {design}"
|
|
125
|
+
)
|
|
126
|
+
mma_supported_targets = design.mma_instruction_type.supported_targets(
|
|
127
|
+
design, operands
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
metadata_to_check.append(mma_supported_targets)
|
|
131
|
+
|
|
132
|
+
# check against tile scheduler supported targets
|
|
133
|
+
if hasattr(design, "tile_scheduler") and design.tile_scheduler:
|
|
134
|
+
tile_scheduler_supported_targets = design.tile_scheduler.supported_targets(
|
|
135
|
+
design, operands
|
|
136
|
+
)
|
|
137
|
+
metadata_to_check.append(tile_scheduler_supported_targets)
|
|
138
|
+
return list(set.intersection(*map(set, metadata_to_check)))
|
|
139
|
+
|
|
140
|
+
def is_portable_to(self, other: TargetSm | str) -> bool:
|
|
141
|
+
"""Check if this target can compile/run on architecture described by `other` TargetSm."""
|
|
142
|
+
other = TargetSm.ensure(other)
|
|
143
|
+
match self.portability:
|
|
144
|
+
case ArchPortability.Portable:
|
|
145
|
+
return self.cc <= other.cc
|
|
146
|
+
case ArchPortability.FamilyPortable:
|
|
147
|
+
return self.major == other.major and self.cc <= other.cc
|
|
148
|
+
case ArchPortability.ArchConditional:
|
|
149
|
+
return self == other
|
|
150
|
+
case _:
|
|
151
|
+
raise NotImplementedError
|
|
152
|
+
|
|
153
|
+
def supports_operators_from(
|
|
154
|
+
self, operator_targets: Iterable[TargetSm | str]
|
|
155
|
+
) -> bool:
|
|
156
|
+
"""Check if this target can compile/run operators designed for any of the given targets."""
|
|
157
|
+
return any(
|
|
158
|
+
TargetSm.ensure(target).is_portable_to(self) for target in operator_targets
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
@property
|
|
162
|
+
def major(self) -> int:
|
|
163
|
+
"""Get the major version of the target SM."""
|
|
164
|
+
return self.cc // 10
|
|
165
|
+
|
|
166
|
+
@property
|
|
167
|
+
def minor(self) -> int:
|
|
168
|
+
"""Get the minor version of the target SM."""
|
|
169
|
+
return self.cc % 10
|
|
170
|
+
|
|
171
|
+
@staticmethod
|
|
172
|
+
def _parse_str(s: str) -> tuple[int, ArchPortability]:
|
|
173
|
+
"""Parse a compile target string (e.g., "100", "100a", "100f")."""
|
|
174
|
+
portability_map = {
|
|
175
|
+
"a": ArchPortability.ArchConditional,
|
|
176
|
+
"f": ArchPortability.FamilyPortable,
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
suffix = s[-1]
|
|
180
|
+
try:
|
|
181
|
+
if suffix.isdigit():
|
|
182
|
+
return (int(s), ArchPortability.Portable)
|
|
183
|
+
return (int(s[:-1]), portability_map[suffix])
|
|
184
|
+
except (ValueError, KeyError):
|
|
185
|
+
raise ValueError(
|
|
186
|
+
f"Invalid TargetSm: '{s}'. Expected form: '100', '100a', '100f'"
|
|
187
|
+
) from None
|
|
188
|
+
|
|
189
|
+
def __str__(self) -> str:
|
|
190
|
+
suffix = (
|
|
191
|
+
self.portability.name[0].lower()
|
|
192
|
+
if self.portability != ArchPortability.Portable
|
|
193
|
+
else ""
|
|
194
|
+
)
|
|
195
|
+
return f"{self.cc}{suffix}"
|
|
196
|
+
|
|
197
|
+
def __hash__(self) -> int:
|
|
198
|
+
return hash((self.cc, self.portability))
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class ArchPortability(enum.Enum):
|
|
202
|
+
"""Portability of a compiled Operator to other architectures/compute capabilities.
|
|
203
|
+
|
|
204
|
+
See:
|
|
205
|
+
https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/compute-capabilities.html#feature-set-compiler-targets
|
|
206
|
+
"""
|
|
207
|
+
|
|
208
|
+
Portable = enum_auto()
|
|
209
|
+
"""Portable to future architectures (e.g. sm_100 without "a" or "f")."""
|
|
210
|
+
|
|
211
|
+
FamilyPortable = enum_auto()
|
|
212
|
+
"""Portable to future architectures within the same family (e.g. sm_100f).
|
|
213
|
+
Only applicable to Blackwell and newer architectures (cc >= 100).
|
|
214
|
+
"""
|
|
215
|
+
|
|
216
|
+
ArchConditional = enum_auto()
|
|
217
|
+
"""Not portable to any other architecture (e.g. sm_100a).
|
|
218
|
+
Only applicable to Hopper and newer architectures (cc >= 90).
|
|
219
|
+
"""
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
from .base import (
|
|
31
|
+
Operand,
|
|
32
|
+
PerformanceControls,
|
|
33
|
+
RuntimeArguments,
|
|
34
|
+
)
|
|
35
|
+
from .epilogue import EpilogueArguments
|
|
36
|
+
from .gemm import GemmArguments, GemmProblemSize
|
|
37
|
+
from .grouped_gemm import GroupedGemmArguments
|
|
38
|
+
from .operand import (
|
|
39
|
+
DenseTensor,
|
|
40
|
+
ScaledOperand,
|
|
41
|
+
ScaleMode,
|
|
42
|
+
ScaleSwizzleMode,
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
# Public names re-exported by ``cutlass.operators.arguments`` (its star-import
# surface), grouped by role. Every entry is imported at the top of this module.
__all__ = [
    # Top-level arguments & constituents
    "RuntimeArguments",
    "PerformanceControls",
    "EpilogueArguments",
    # Various RuntimeArguments subtypes
    # NOTE(review): GemmProblemSize is listed in this group but is presumably a
    # helper type rather than a RuntimeArguments subtype; confirm.
    "GemmArguments",
    "GemmProblemSize",
    "GroupedGemmArguments",
    # Operands & related classes
    "Operand",
    "DenseTensor",
    "ScaledOperand",
    "ScaleMode",
    "ScaleSwizzleMode",
]
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-3-Clause
|
|
3
|
+
|
|
4
|
+
# Redistribution and use in source and binary forms, with or without
|
|
5
|
+
# modification, are permitted provided that the following conditions are met:
|
|
6
|
+
|
|
7
|
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
8
|
+
# list of conditions and the following disclaimer.
|
|
9
|
+
|
|
10
|
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
11
|
+
# this list of conditions and the following disclaimer in the documentation
|
|
12
|
+
# and/or other materials provided with the distribution.
|
|
13
|
+
|
|
14
|
+
# 3. Neither the name of the copyright holder nor the names of its
|
|
15
|
+
# contributors may be used to endorse or promote products derived from
|
|
16
|
+
# this software without specific prior written permission.
|
|
17
|
+
|
|
18
|
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
19
|
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
20
|
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
21
|
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
22
|
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
23
|
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
24
|
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
25
|
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
26
|
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
27
|
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import copy
|
|
32
|
+
from abc import ABC, abstractmethod
|
|
33
|
+
from dataclasses import dataclass, field, fields
|
|
34
|
+
from typing import Any, final, get_type_hints
|
|
35
|
+
|
|
36
|
+
from cutlass.operators.typing import NumericLike, TensorLike
|
|
37
|
+
from cutlass.operators.utils.dtype import to_cutlass_type
|
|
38
|
+
from cutlass.operators.utils.tensor import TensorWrapper
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
|
|
42
|
+
class RuntimeArguments:
|
|
43
|
+
"""Describes the operands and all other arguments passed to an Operator at runtime.
|
|
44
|
+
|
|
45
|
+
It contains runtime operands (usually tensors) passed to the operation, as well as
|
|
46
|
+
any custom epilogue fusions, and runtime performance controls.
|
|
47
|
+
|
|
48
|
+
This is an abstract base class, whose subclass describes the operation type itself
|
|
49
|
+
(e.g. GemmArguments, GroupedGemmArguments, etc.). All operators that implement the same
|
|
50
|
+
operation type accept the same RuntimeArguments subclass.
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
performance: PerformanceControls | None = field(default=None, kw_only=True)
|
|
54
|
+
"""Optional runtime performance controls passed to the Operator"""
|
|
55
|
+
|
|
56
|
+
def _validate(self):
|
|
57
|
+
"""Checks that the arguments are valid.
|
|
58
|
+
|
|
59
|
+
This is run before all fields have been converted to TensorWrapper and cutlass.Numeric.
|
|
60
|
+
"""
|
|
61
|
+
|
|
62
|
+
def __post_init__(self):
|
|
63
|
+
_convert_to_internal_types(self)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class PerformanceControls:
    """General container for runtime performance controls passed to an Operator.

    Some Operators expose performance options that can be adjusted at runtime;
    all such options live here. The base container currently declares no fields.
    """

    pass
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class Operand(ABC):
    """Base class for all operands to Operators, which encapsulates one or more TensorLike objects.

    In the most basic case, an Operand encapsulates a single tensor.

    In more complex cases, an Operand may encapsulate several tensors that together form a single
    logical operand. For instance, a :class:`~cutlass.operators.ScaledOperand` encapsulates a
    quantized tensor and a scale tensor that together reconstruct the logical value of the operand.

    Subclasses must implement :meth:`__copy__`.
    """

    @final
    def copy(self) -> Operand:
        """Return a shallow copy of the operand; the underlying tensor(s) are not copied.

        Delegates to :meth:`__copy__`, so ``op.copy()`` and ``copy.copy(op)`` agree.
        """
        return self.__copy__()

    @abstractmethod
    def __copy__(self) -> Operand:
        """Return a copy of the operand without copying the underlying tensor(s).

        Raises:
            NotImplementedError: Always, in the base class; subclasses must override.
        """
        raise NotImplementedError

    def _convert_to_internal_types(self, metadata: dict[str, Any] | None = None):
        """Convert this operand's fields to internal types, in place.

        Args:
            metadata (dict[str, Any] | None): Additional metadata forwarded to the
                module-level conversion routine; ``None`` means no extra metadata.
        """
        # Resolves to the module-level function of the same name, not this method.
        _convert_to_internal_types(self, metadata=metadata)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _convert_to_internal_types(caller, metadata: dict[str, Any] | None = None):
    """Converts fields of the dataclass instance ``caller`` to internal types, in place.

    Current fields that are converted:
        * ``TensorLike`` -> ``TensorWrapper``
        * ``NumericLike`` -> ``cutlass.Numeric``
        * Classes that implement ``_convert_to_internal_types`` -> their internal types

    A field is treated as ``TensorLike`` / ``NumericLike`` only when its resolved type
    hint is exactly that object (identity check). Values that are already
    ``TensorWrapper`` instances are not converted again.

    Args:
        caller (Any): The object whose fields are converted. It must be a dataclass
            instance (``dataclasses.fields`` is applied to it) whose annotations can be
            resolved by ``typing.get_type_hints``.
        metadata (dict[str, Any] | None): Additional metadata to be used for conversion.
            For each field it is deep-copied and then overlaid with that field's own
            ``dataclasses.field(metadata=...)`` entries (field entries win on key
            conflicts). The merged mapping is passed as keyword arguments to
            ``TensorWrapper`` and as ``metadata=`` to nested
            ``_convert_to_internal_types`` calls.
    """
    type_hints = get_type_hints(type(caller))
    for f in fields(caller):
        hint = type_hints.get(f.name)
        value = getattr(caller, f.name)

        # A fresh deep copy per field: the update below must neither leak one
        # field's metadata into the next field nor mutate the caller's mapping.
        global_metadata = {} if metadata is None else copy.deepcopy(metadata)
        global_metadata.update(f.metadata)

        if isinstance(value, TensorWrapper):
            # No conversion needed; the value is re-assigned unchanged.
            setattr(caller, f.name, value)
        elif hint is TensorLike:
            # Find all fields that are annotated as TensorLike,
            # and wrap them in TensorWrapper
            setattr(caller, f.name, TensorWrapper(value, **global_metadata))
        elif hint is NumericLike:
            # Find all fields that are annotated as NumericLike,
            # and convert them to cutlass.Numeric
            setattr(caller, f.name, to_cutlass_type(value))
        elif hasattr(value, "_convert_to_internal_types"):
            # Nested objects (e.g. Operands) convert their own fields in place,
            # inheriting the merged metadata of the field that holds them.
            value._convert_to_internal_types(metadata=global_metadata)
            setattr(caller, f.name, value)
|