nvfuser_cu121_torch25-0.2.25.dev20250201-cp312-cp312-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-312-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +16 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
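The wheel ships the `nvfuser` Python package: native bindings in `nvfuser/_C`, the codegen library `nvfuser/lib/libnvfuser_codegen.so`, and the C++ headers under `nvfuser/include/`, built against torch 2.5 and CUDA 12.1 (per the package name). A minimal smoke test after installing the wheel might look like the sketch below; it is illustrative only and assumes a working install with at least one visible CUDA device.

# Illustrative smoke test, not part of the wheel contents listed above.
# Assumes the wheel is installed into an environment with a matching
# torch 2.5 / CUDA 12.1 build and at least one CUDA device.
import torch
import nvfuser

print(torch.cuda.is_available())  # should print True on a working install
print(nvfuser.version())          # e.g. "0.2.25.dev20250201"; comparable as a Version, see version() below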
nvfuser/__init__.py ADDED
@@ -0,0 +1,618 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import logging
import os
import re
import sys
from typing import Callable
import warnings

import torch

# This is needed when libnvfuser.so is patched and doesn't have the pytorch library location available.
pytorch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
if pytorch_lib_dir not in sys.path:
    sys.path.append(pytorch_lib_dir)

# Import _C here explicitly; otherwise a failure in this Python script would
# surface later as a confusing complaint that `_C` is not defined for
# `_C._FusionDefinition`.
from . import _C
from ._C import *  # noqa: F401,F403

from . import contrib  # noqa: F401


logger = logging.getLogger("nvfuser")


# Register automatic serialization of Nvfuser cache hierarchy and cuda kernels.
def enable_automatic_serialization():
    import atexit

    atexit.register(_C.serialize)

    # A separate process is created for each device in a distributed setting.
    # Each FusionCache becomes associated with a single device.
    # Automatic serialization saves a separate cache for each device.
    # Set the FusionCache id to the ddp local rank.
    env_var_ddp_local_rank = os.environ.get("LOCAL_RANK", None)
    if env_var_ddp_local_rank is not None:
        env_var_ddp_local_rank = int(env_var_ddp_local_rank)
    _C.FusionCache.get(max_fusions := 8192, env_var_ddp_local_rank)


# Unregister automatic serialization of Nvfuser cache hierarchy and cuda kernels.
def disable_automatic_serialization():
    import atexit

    atexit.unregister(_C.serialize)

class FusionDefinition(_C._FusionDefinition):
    def __init__(self, id=None, max_length=1024):
        super(FusionDefinition, self).__init__(id, max_length)
        self.profiled = False

    def segment(self, inputs):
        """
        Decompose this FusionDefinition into a sequence of segment
        FusionDefinitions.

        This function runs the nvfuser segmentation algorithm and translates the
        segments into their corresponding FusionDefinitions.

        Args:
            inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.

        Returns:
            List[FusionDefinition]: The FusionDefinitions corresponding to the
                sub-fusion segments of this FusionDefinition.
        """
        num_segments = self._setup_segmentation(inputs)
        if num_segments == 1:
            self._finalize_segmentation()
            return []

        # Track all segments for this FusionDefinition
        self.segments = []

        # Track map_segment_fid_to_original_fid for each segment
        self.segment_index_space_maps = {}

        # Track the last segment a value is used as an input
        self.map_value_to_last_used_segment = {}

        for idx in range(num_segments):
            new_fd = FusionDefinition()
            map_segment_fid_to_original_fid = self._build_segment(new_fd, idx)

            for segment_input in new_fd.inputs():
                original_input = map_segment_fid_to_original_fid[segment_input]
                self.map_value_to_last_used_segment[original_input] = idx

            self.segment_index_space_maps[new_fd] = map_segment_fid_to_original_fid
            self.segments.append(new_fd)
        self._finalize_segmentation()
        return self.segments

    def __enter__(self):
        return self._setup_definition()

    def __exit__(self, type, value, traceback):
        try:
            self._finalize_definition()
        except Exception as err:
            logger.exception(self._repro_error_str("defining"))
            raise

    def definition(self):
        raise NotImplementedError("definition() should be implemented by child class!")

    def _execute_segments(self, input_arguments, *, device=None, profile=False):
        """
        Run the sequence of FusionDefinition segments to generate the results
        of this FusionDefinition.

        This FusionDefinition acts as an argument manager. It gathers input
        arguments for the segments and stores their output results. After
        running a segment, any redundant intermediate values, which are
        unnecessary for any other segments, are deleted to save memory.

        Args:
            inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.

        Kwargs:
            device (Optional[Union[int, str, torch.device]]): This is a hint to run
                the Fusion on the given CUDA device. This is not typically
                necessary, as the device is usually inferred from the locations
                of input tensors. However, for some fusion definitions, no
                tensors will be input (for example when all tensors are
                generated with `full` or `uniform` ops). In these cases, we
                must either tell NVFuser where to run the resulting kernel, or
                let it default to 0. Note that passing this option while
                providing input tensors that lie on another device is an error.
            profile (bool): Captures a CUPTI based profile of a fusion.

        Returns:
            List[Tensor]: The output results for this FusionDefinition.
        """
        assert len(self.segments) > 0
        assert len(self.segments) == len(self.segment_index_space_maps)

        input_arguments_with_extents = [*input_arguments]
        for a in input_arguments:
            if type(a) is torch.Tensor:
                input_arguments_with_extents.extend(a.size())

        # Map input arguments to original fid
        map_original_fid_to_value = {
            fd_state: argument
            for fd_state, argument in zip(
                self.inputs() + self.extents(), input_arguments_with_extents
            )
        }

        # Run all segments in correct order
        for idx, segment in enumerate(self.segments):
            segment_to_original_map = self.segment_index_space_maps[segment]

            # Gather segment input arguments
            segment_arguments = [
                map_original_fid_to_value[segment_to_original_map[fd_state]]
                for fd_state in segment.inputs()
            ]

            # Run segment
            segment_outputs = segment.execute(
                segment_arguments, device=device, profile=profile
            )

            # Update original fusion definition indices to outputs
            for fd_state, output in zip(segment.outputs(), segment_outputs):
                map_original_fid_to_value[segment_to_original_map[fd_state]] = output

            # Destroy any arguments that are not used by future segments
            for segment_input in segment.inputs():
                original_input = segment_to_original_map[segment_input]
                if (
                    original_input not in self.outputs()
                    and self.map_value_to_last_used_segment[original_input] == idx
                ):
                    del map_original_fid_to_value[original_input]

        # Map output fid to actual results
        return [map_original_fid_to_value[fd_state] for fd_state in self.outputs()]

    def execute(
        self,
        inputs,
        *,
        device=None,
        override_user_schedule=False,
        capture_debug_output=False,
        print_repro=False,
        profile=False,
        save_repro_inputs=False,
        _enable_options: list[str] = [],
        _disable_options: list[str] = [],
    ) -> list[torch.Tensor | DistributedTensor]:
        """
        Executes an nvFuser set of kernels for a given Fusion

        The FusionDefinition will be executed on a single CUDA device.
        Typically, which device to run on is determined by the devices where
        the input tensors reside. However, if the Fusion is defined such that
        none of the inputs are tensors, we are not able to infer a device from
        the inputs. For example, the following FusionDefinition will be unable
        to unambiguously infer the device of its output:

            with FusionDefinition() as fd:
                tv1 = fd.ops.full([5])
                fd.add_output(tv1)

        In that case, we default to selecting the first CUDA
        device, i.e. `torch.device("cuda:0")`. This method enables selecting an
        alternative preferred device.

        Args:
            inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.

        Kwargs:
            device (Optional[Union[int, str, torch.device]]): This is a hint to run
                the Fusion on the given CUDA device. This is not typically
                necessary, as the device is usually inferred from the locations
                of input tensors. However, for some fusion definitions, no
                tensors will be input (for example when all tensors are
                generated with `full` or `uniform` ops). In these cases, we
                must either tell NVFuser where to run the resulting kernel, or
                let it default to 0. Note that passing this option while
                providing input tensors that lie on another device is an error.
            override_user_schedule (bool): For a user defined schedule,
                override with auto-generated schedule (default: False)
            capture_debug_output (bool): Whether to capture any printed
                debugging information as a string. If True, the string can be
                retrieved after execution using :meth:`debug_output`. If False,
                then that method will return None when called.
            print_repro (bool): Prints a reproduction script to stdout.
            profile (bool): Captures a CUPTI based profile of a fusion.
            save_repro_inputs (bool): Saves the inputs for last_repro_script() to
                provide a reproduction script.
            _enable_options/_disable_options (list): NVFUSER_ENABLE/DISABLE options to use.
                This is an alternative to environment variables.
                Note: Currently, we do not cache/store these options in the FusionCache,
                so kernels may be inadvertently reused when executing the same fusion
                definition with different sets of options. Reset the FusionCache manually
                to avoid kernel reuse between different sets of options.

        Returns:
            List[Tensor]
        """
        self.profiled = profile

        if device is not None:
            if not isinstance(device, torch.device):
                device = torch.device(device)
            assert (
                device.type == "cuda"
            ), "If device argument is passed it must be a CUDA device"
            device = device.index

        # if definition is not defined by a context manager, try a child class
        if self.id() is None:
            self._setup_definition()
            self.definition()
            self._finalize_definition()

        defined_multidevice_schedule = hasattr(
            self, "multidevice_schedule"
        ) and isinstance(self.multidevice_schedule, Callable)
        defined_schedule = hasattr(self, "schedule") and isinstance(
            self.schedule, Callable
        )
        assert not (
            defined_multidevice_schedule and defined_schedule
        ), "I haven't tested what happens if both are defined. We don't plan to support this use case, although it may just work."

        if defined_multidevice_schedule:
            # Unlike `schedule`, `multidevice_schedule` is designed for inter-device
            # scheduling. The scheduling is done before concretization and therefore
            # before pre-segmentation. `schedule` however assumes the FusionDefinition
            # has been concretized and pre-segmented, and therefore requires
            # `_setup_schedule` and `_finalize_schedule` to be called before and after.
            #
            # Note: there's a plan to embed multidevice schedules into FusionDefinition
            # as annotating nodes. This may eventually replace `multidevice_schedule`.
            self._setup_multidevice_schedule()
            self.multidevice_schedule()
            self._finalize_multidevice_schedule()

        # If a schedule is defined by a child class and no schedule exists yet for
        # these inputs, make one.
        if defined_schedule:
            # Schedule the fusion if a schedule does not exist yet, or when profiling
            if profile or not self._exist_schedule(inputs):
                self._setup_schedule(inputs, overwrite_existing_schedule=profile)
                self.schedule()
                self._finalize_schedule(inputs)

        if save_repro_inputs:
            from torch._subclasses.fake_tensor import FakeTensorMode

            fake_mode = FakeTensorMode()
            self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

        if hasattr(self, "segments") and len(self.segments) > 0:
            return self._execute_segments(inputs, device=device, profile=profile)

        try:
            if print_repro:
                print(self.repro_script_for(inputs))
            if len(_enable_options) or len(_disable_options):
                warnings.warn(
                    "Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
                )

            out_tensors: list[DistributedTensor] = self._execute(
                inputs,
                device=device,
                override_user_schedule=override_user_schedule,
                capture_debug_output=capture_debug_output,
                profile=profile,
                _enable_options=_enable_options,
                _disable_options=_disable_options,
            )
            for i, out_tensor in enumerate(out_tensors):
                if out_tensor.mesh.size == 0:
                    out_tensors[i] = out_tensor.local
            return out_tensors
        except Exception as err:
            logger.exception(self._repro_error_str("executing", inputs))
            raise

    def debug_output(self):
        """
        Retrieve string of captured debug information from the previous execution.

        Note that `capture_debug_output=True` must be passed to `execute()` in
        order to enable capturing this output. Otherwise, this method will
        return `None`.

        Returns:
            Optional[String] : the captured debug output for the previous call
                to execute(). If the `capture_debug_output` argument to that call
                was False, returns None. Otherwise, returns the output as a string.
        """
        return self._debug_output()

    def from_pytorch(self, tensor, static_sizes=False):
        """
        Defines an nvfuser input tensor from a pytorch tensor and defaults
        to defining a symbolic tensor for dynamic shape usage.

        Args:
            tensor (torch.Tensor): Input tensor to nvFuser
            static_sizes (bool) : Interprets sizes as static rather than
                as symbolic for dynamic shape usage

        Returns:
            nvfuser.Tensor
        """
        try:
            from .pytorch_utils import torch_dtype_to_nvfuser_dtype
        except ImportError:
            raise ImportError("Unable to import pytorch_utils!")

        if not tensor.is_cuda and len(tensor.size()) != 0:
            raise ValueError("CPU non-scalar tensor is not supported!")

        return self.define_tensor(
            sizes=tensor.size(),
            strides=tensor.stride(),
            dtype=torch_dtype_to_nvfuser_dtype(tensor.dtype),
            static_sizes=static_sizes,
            is_cpu=tensor.is_cpu,
        )

    def fusion_ir(self):
        """
        Returns the unscheduled Fusion IR for the given definition that corresponds to all scheduled inputs.

        Returns:
            String
        """
        return self._fusion_ir()

    def last_cuda_code(self, intrinsic_code=False, **kwargs):
        """
        Returns the Cuda Code for the last executed set of inputs

        Args:
            intrinsic_code (Bool): Include all the additional code required to run kernel(s). (default: False)

        Kwargs:
            override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False)

        Returns:
            String
        """
        override_user_schedule = kwargs.pop("override_user_schedule", False)
        return self._last_cuda_code(intrinsic_code, override_user_schedule)

    def cuda_code_for(self, inputs, intrinsic_code=False, **kwargs):
        """
        Returns the Cuda Code for the given inputs

        Args:
            inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.
            intrinsic_code (Bool): Include all the additional code required to run kernel(s). (default: False)

        Kwargs:
            override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False)

        Returns:
            String
        """
        override_user_schedule = kwargs.pop("override_user_schedule", False)
        return self._cuda_code_for(inputs, intrinsic_code, override_user_schedule)

    def last_scheduled_fusion_ir(self, tensor_transforms=False, **kwargs):
        """
        Returns the Scheduled Fusion IR for the last executed set of inputs

        Args:
            tensor_transforms (Bool): Include tensor transforms that were applied through scheduling. (default: False)

        Kwargs:
            override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False)

        Returns:
            String
        """
        override_user_schedule = kwargs.pop("override_user_schedule", False)
        return self._last_scheduled_fusion_ir(tensor_transforms, override_user_schedule)

    def scheduled_fusion_ir_for(self, inputs, tensor_transforms=False, **kwargs):
        """
        Returns the Scheduled Fusion IR for the given inputs

        Args:
            inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.
            tensor_transforms (Bool): Include tensor transforms that were applied through scheduling. (default: False)

        Kwargs:
            override_user_schedule (Bool): For a user defined schedule, override with auto-generated schedule (default: False)

        Returns:
            String
        """
        override_user_schedule = kwargs.pop("override_user_schedule", False)
        return self._scheduled_fusion_ir_for(
            inputs, tensor_transforms, override_user_schedule
        )

    def profile(self):
        """
        Returns the FusionProfile object from the CUPTI based FusionProfiler

        Returns:
            FusionProfile
        """
        if not self.profiled:
            raise ValueError(
                "The execute() method was not previously called with profiling enabled!"
            )

        fp = self._profile()

        if fp.fusion_id < 0:
            raise ValueError(
                "Something went wrong with Fusion Profiling as an illegal fusion_id was returned! "
                + str(fp.fusion_id)
            )
        if fp.segments < 1:
            raise ValueError(
                "Something went wrong with Fusion Profiling as no kernel segments were profiled! "
                + str(fp.segments)
            )

        return fp

    def last_repro_script(self) -> str:
        assert (
            self.fake_inputs is not None
        ), "fd.last_repro_script() cannot provide a repro because fd.execute(inputs, save_repro_inputs=True) was not executed!"
        script = self.repro_script_for(self.fake_inputs)
        return script

    def repro_script_for(self, inputs: list | None = None) -> str:
        msg = "# CUDA devices:\n"
        for i in range(torch.cuda.device_count()):
            msg += f"# {i}: {torch.cuda.get_device_name(i)}\n"
        msg += (
            f"# torch version: {torch.__version__}\n"
            f"# cuda version: {torch.version.cuda}\n"
            f"# nvfuser version: {version()}\n"
            "import torch\n"
            "from nvfuser import FusionDefinition, DataType\n"
            f"{self}"
            "with FusionDefinition() as fd:\n"
            f" nvfuser_fusion_id{self.id()}(fd)\n"
        )
        if inputs is not None:
            msg += "\ninputs = [\n"
            for i in inputs:
                if isinstance(i, torch.Tensor):
                    if i.is_contiguous():
                        msg += f" torch.testing.make_tensor({tuple(i.size())}, dtype={i.dtype}, device='{i.device}'),\n"
                    else:
                        # max linear index determines number of elements to generate
                        sz = 1
                        for szi, stri in zip(i.size(), i.stride()):
                            if szi == 0:
                                sz = 0
                                break
                            sz += (szi - 1) * stri
                        if i.dtype.is_floating_point:
                            msg += (
                                f" torch.randn({sz}, dtype={i.dtype}, device='{i.device}')"
                                f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n"
                            )
                        else:
                            upper_bound = 2 if i.dtype == torch.bool else 10
                            msg += (
                                f" torch.randint(0, {upper_bound}, ({sz},), dtype={i.dtype}, device='{i.device}')"
                                f".as_strided({tuple(i.size())}, {tuple(i.stride())}),\n"
                            )
                else:
                    input_as_string = str(i)
                    # `nan` and `inf` are stringified as is, which are not
                    # defined in Python. So we replace them with `float("nan")`
                    # and `float("inf")`. `-inf` is replaced with
                    # `-float("inf")`, which equals `float("-inf")`.
                    input_as_string = re.sub(
                        r"\binf\b", 'float("inf")', input_as_string
                    )
                    input_as_string = re.sub(
                        r"\bnan\b", 'float("nan")', input_as_string
                    )
                    msg += f" {input_as_string},\n"
            msg += "]"
            msg += "\nfd.execute(inputs)\n"

        return msg

    def _repro_error_str(self, section: str, inputs: list | None = None):
        msg = (
            f"An error occurred while {section} nvFuser FusionDefinition {self.id()}.\n"
            "If you believe this is a bug or need assistance, please file an issue at "
            "https://github.com/NVIDIA/Fuser/issues/new\n"
            f"Here's a script to reproduce the error:\n"
            "```python\n"
        )
        msg += self.repro_script_for(inputs)
        msg += "```\n"
        return msg

    def validate(
        self,
        inputs: list[torch.Tensor],
        reference_outputs: list[torch.Tensor],
        kwargs=None,
    ):
        """
        Validates the fusion outputs against the provided reference outputs, using
        variable tolerances determined based on datatype and reduction size.

        Args:
            inputs: A list of inputs expected by the fusion definition
            reference_outputs: A list of reference outputs to validate against
        """
        fusion_outputs = self.execute(inputs)
        assert len(fusion_outputs) == len(
            reference_outputs
        ), f"Expected {len(fusion_outputs)} reference outputs for validation."

        tolerance_values = self.getValTolerances(inputs)
        assert len(tolerance_values) == len(
            fusion_outputs
        ), f"Missing tolerance values, expected {len(fusion_outputs)}, got {len(tolerance_values)}"

        for inx, fusion_output in enumerate(fusion_outputs):
            atol, rtol = tolerance_values[inx]
            reference_output = reference_outputs[inx]

            assert (
                reference_output.shape == fusion_output.shape
            ), "Mismatch in reference and fusion output dimensions"
            if torch.is_floating_point(fusion_output) or torch.is_complex(
                fusion_output
            ):
                assert torch.allclose(
                    fusion_output, reference_output, atol=atol, rtol=rtol
                ), (
                    f"Max error: {torch.abs(torch.max(fusion_output - reference_output))}, "
                    f"Absolute tolerance: {atol}, Relative tolerance: {rtol}"
                )
            else:
                assert torch.equal(
                    fusion_output, reference_output
                ), "Mismatch in reference and fusion output values, datatype is not float/complex."


from .nvfuser_version import __version__


def version():
    r"""Returns the nvfuser version as a string in the format 'm.n.p+git[7d-sha]'.

    We strip the git[7d-sha] and convert the string to
    `nvfuser_version.Version` for comparison. e.g. you can use it as:
        import nvfuser
        print(nvfuser.version())              # 0.0.1+git21df524
        nvfuser.version() == '0.0.1'          # True
        nvfuser.version() > '0.0.0'           # True

        from nvfuser_version import Version
        nvfuser.version() < Version('1.0.0')  # True
    """
    return __version__
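Taken together, the module above defines the user-facing workflow: build a fusion inside a `with FusionDefinition()` block (or in a `definition()` method on a subclass), call `execute()` on concrete inputs, and optionally `validate()` against a PyTorch reference. The sketch below illustrates that flow; `fd.ops.add` and `fd.add_output` are assumed from the `nvfuser._C` bindings (they are not defined in this `__init__.py`), and the shapes are arbitrary.

# Illustrative sketch of the FusionDefinition workflow defined above.
# Assumes fd.ops.add and fd.add_output from the nvfuser._C bindings and a CUDA device.
import torch
from nvfuser import FusionDefinition

inputs = [
    torch.randn(4, 8, device="cuda"),
    torch.randn(4, 8, device="cuda"),
]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])  # symbolic input tensors (dynamic shapes)
    t1 = fd.from_pytorch(inputs[1])
    t2 = fd.ops.add(t0, t1)
    fd.add_output(t2)

# execute() infers the target device from the input tensors.
(out,) = fd.execute(inputs)

# validate() re-runs the fusion and compares against a reference using
# dtype-aware tolerances (see validate() above).
fd.validate(inputs, [inputs[0] + inputs[1]])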
nvfuser/__init__.pyi ADDED
@@ -0,0 +1,4 @@
from typing import List


def compute_contiguity(sizes, strides) -> List[bool]: ...
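The stub declares `compute_contiguity`, which maps a tensor's sizes and strides to per-dimension contiguity flags. A hedged example of the call shape, assuming the binding accepts plain Python sequences; the all-True result is what one would expect for a fully contiguous layout, and the exact collapsing rules for broadcast or size-1 dimensions are not shown here.

# Hedged example for compute_contiguity; signature taken from the stub above.
import torch
from nvfuser import compute_contiguity

t = torch.empty(4, 8, 16)
print(compute_contiguity(list(t.size()), list(t.stride())))  # expected: [True, True, True]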
nvfuser/contrib/__init__.py ADDED
@@ -0,0 +1,9 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from . import nn


__all__ = [
    "nn",
]
nvfuser/contrib/nn/__init__.py ADDED
@@ -0,0 +1,13 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from .normalization import InstanceNorm1dNVFuser
from .normalization import InstanceNorm2dNVFuser
from .normalization import InstanceNorm3dNVFuser


__all__ = [
    "InstanceNorm1dNVFuser",
    "InstanceNorm2dNVFuser",
    "InstanceNorm3dNVFuser",
]
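nvfuser.contrib.nn re-exports the fused instance-normalization modules implemented in normalization.py. A hedged usage sketch follows, assuming InstanceNorm2dNVFuser mirrors the constructor arguments of torch.nn.InstanceNorm2d.

# Hedged sketch: using the fused instance norm as a drop-in module.
# Assumes the constructor mirrors torch.nn.InstanceNorm2d
# (num_features, eps, momentum, affine, track_running_stats) and a CUDA device.
import torch
from nvfuser.contrib.nn import InstanceNorm2dNVFuser

norm = InstanceNorm2dNVFuser(num_features=16, affine=True).cuda()
x = torch.randn(8, 16, 32, 32, device="cuda")
y = norm(x)
print(y.shape)  # torch.Size([8, 16, 32, 32])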