nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,61 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <exceptions.h>
11
+ #include <ir/interface_nodes.h>
12
+ #include <type.h>
13
+
14
+ namespace nvfuser {
15
+
16
//! Options controlling how operand dtypes are promoted.
struct TypePromotionConfig {
  //! When true, integer inputs are promoted to a floating-point type (as the
  //! field name states; applied by computeTypes/promoteValues).
  bool promote_integer_inputs_to_float = false;
  // Checks the promoted type is either single or double.
  bool require_full_precision_promoted = false;
};

//! Predefined promotion configurations.
//!
//! These are `inline constexpr` (C++17) rather than `static const`: a
//! `static` object defined in a header gives every translation unit its own
//! private copy, whereas an inline variable has exactly one definition in the
//! whole program and is also usable in constant expressions.
namespace TypePromotion {

inline constexpr TypePromotionConfig comparison_op_config{};
inline constexpr TypePromotionConfig default_op_config{};
inline constexpr TypePromotionConfig float_op_config{
    /* promote_integer_inputs_to_float */ true,
    /* require_full_precision_promoted */ false};
inline constexpr TypePromotionConfig float_only_op_config{
    /* promote_integer_inputs_to_float */ false,
    /* require_full_precision_promoted */ true};

} // namespace TypePromotion
34
+
35
+ // Implements the the behavior of the following flags:
36
+ // - promote_inputs_to_common_dtype
37
+ // - promote_integer_inputs_to_float
38
+ DataType computeTypes(
39
+ const TypePromotionConfig& config,
40
+ const std::vector<Val*>& operands,
41
+ const bool cast_half_to_float = true);
42
+
43
+ // Computes the common dtype for the given operands
44
+ // Casts operands to common dtype if necessary
45
+ // Automatically cast FP16/BF16 dtype to Float
46
+ std::vector<Val*> promoteValues(
47
+ const TypePromotionConfig& config,
48
+ const std::vector<Val*>& operands);
49
+
50
+ std::vector<Val*> promoteValues(
51
+ const std::vector<Val*>& operands,
52
+ DataType common_type);
53
+
54
+ // Casts value to common dtype if necessary
55
+ // Avoid cast if value's dtype matches its dtype class
56
+ Val* optionalCast(DataType dtype, Val* v);
57
+
58
+ // Casts value to common dtype if necessary
59
+ Val* optionalCastStrict(DataType dtype, Val* v);
60
+
61
+ } // namespace nvfuser
@@ -0,0 +1,619 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
#include <ATen/ATen.h>
#include <exceptions.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/torch.h>
#include <visibility.h>

#include <debug.h>
#include <mma_type.h>
#include <tma.h>
#include <type.h>

#include <c10/core/thread_pool.h>
#include <algorithm>
#include <deque>
#include <initializer_list>
#include <iterator>
#include <limits>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
32
+
33
//! Compare the TORCH_VERSION_{MAJOR,MINOR,PATCH} macros (presumably from the
//! PyTorch headers) against major.minor.patch, lexicographically: major
//! first, then minor, then patch.
//!  - NVF_TORCH_VERSION_GREATER: strictly newer than major.minor.patch.
//!  - NVF_TORCH_VERSION_NO_LESS: the same version or newer.
//! The minor/patch comparisons apply only when the more significant
//! components are equal. The whole expansion and every argument are
//! parenthesized, so the macros compose safely with `!`, `&&`, etc., and
//! work in `#if`.
#define NVF_TORCH_VERSION_GREATER(major, minor, patch)  \
  (TORCH_VERSION_MAJOR > (major) ||                     \
   (TORCH_VERSION_MAJOR == (major) &&                   \
    (TORCH_VERSION_MINOR > (minor) ||                   \
     (TORCH_VERSION_MINOR == (minor) &&                 \
      TORCH_VERSION_PATCH > (patch)))))

#define NVF_TORCH_VERSION_NO_LESS(major, minor, patch)  \
  (TORCH_VERSION_MAJOR > (major) ||                     \
   (TORCH_VERSION_MAJOR == (major) &&                   \
    (TORCH_VERSION_MINOR > (minor) ||                   \
     (TORCH_VERSION_MINOR == (minor) &&                 \
      TORCH_VERSION_PATCH >= (patch)))))
42
+
43
+ //! IR header hierarchy
44
+ //! 1. ** utils.h ** - PolymorphicBase and NonCopyable
45
+ //! 2. ir/base_nodes.h - Statement, Expr, and Val
46
+ //! 3. ir/internal_base_nodes.h - IterDomain and TensorDomain
47
+ //! 4. ir/interface_nodes.h - TensorView and Scalar
48
+ //! 5. ir/internal_nodes.h - Any internal-only IR nodes
49
+
50
+ namespace nvfuser {
51
+
52
+ int getNumThreads();
53
+ c10::ThreadPool* getThreadPool();
54
+
55
+ std::string debug_str(const c10::IValue& val);
56
+ std::string debug_str(const at::Tensor& tensor);
57
+
58
+ bool is_cpu_scalar(const at::Tensor& tensor);
59
+
60
+ bool is_meta_scalar(const at::Tensor& tensor);
61
+
62
+ //! Find common device among tensor inputs. If no tensor inputs are found and
63
+ //! the selected_device argument is omitted, a default value of 0 is returned.
64
+ //! If no tensor inputs are found and selected_device is provided,
65
+ //! selected_device will be returned. If tensor inputs are found their devices
66
+ //! must match one another, and if selected_device is given they must match it
67
+ //! as well, otherwise -1 is returned.
68
+ int8_t getCommonDeviceCUDA(
69
+ const at::ArrayRef<c10::IValue>& inputs,
70
+ std::optional<int8_t> selected_device = std::nullopt);
71
+
72
+ int64_t getRegPerThreadGivenThreadsPerSM(int64_t threads_per_sm);
73
+
74
+ int64_t getThreadsPerSMGivenRegPerThread(int64_t reg_per_thread);
75
+
76
+ // Check if fallback path should be used which will dispatch to eager mode if
77
+ // any errors are encountered. Helpful for debugging.
78
+ bool useFallback();
79
+
80
//! Integer division rounded up, computed as (dividend + divisor - 1) / divisor.
//! NOTE: this equals the mathematical ceiling only for a positive divisor and
//! a non-negative dividend (e.g. ceilDiv(-4, 2) yields -1, not -2).
constexpr int64_t ceilDiv(int64_t dividend, int64_t divisor) {
  const int64_t biased = dividend + (divisor - 1);
  return biased / divisor;
}

//! Rounds `dividend` up to the next multiple of `divisor` (same domain caveat
//! as ceilDiv).
constexpr int64_t roundUpToMultiple(int64_t dividend, int64_t divisor) {
  return divisor * ceilDiv(dividend, divisor);
}
88
+
89
//! Mixin that removes copy construction and copy assignment from a class.
//! Because the copy operations are user-declared (as deleted), no move
//! operations are implicitly declared either, so a derived class is neither
//! copyable nor movable by default. Usage:
//!
//!   class Foo : public NonCopyable {
//!     ...
//!   };
//!
class NonCopyable {
 public:
  NonCopyable() = default;

  // Copying (and therefore, implicitly, moving) is disabled.
  NonCopyable& operator=(const NonCopyable&) = delete;
  NonCopyable(const NonCopyable&) = delete;
};
103
+
104
//! Common root for a hierarchy of polymorphic classes. It provides:
//!  - a virtual destructor, so deleting through a base pointer is safe;
//!  - `base->as<Derived>()` downcasts;
//!  - `node->isA<T>()`-style runtime type queries.
class PolymorphicBase {
 public:
  virtual ~PolymorphicBase() = default;

  //! Downcast helper, a replacement for static_cast<T*>(ptr): ptr->as<T>().
  //! When NDEBUG is defined and NVFUSER_EXPLICIT_ERROR_CHECK is not, this is
  //! an unchecked static_cast; otherwise it is a dynamic_cast whose failure
  //! raises NVF_ERROR.
  template <class T>
  T* as() {
#if defined(NDEBUG) && !defined(NVFUSER_EXPLICIT_ERROR_CHECK)
    return static_cast<T*>(this);
#else
    T* const downcast_ptr = dynamic_cast<T*>(this);
    NVF_ERROR(downcast_ptr != nullptr);
    return downcast_ptr;
#endif // defined(NDEBUG) && !defined(NVFUSER_EXPLICIT_ERROR_CHECK)
  }

  //! Const overload of as<T>(), with the same checking rules.
  template <class T>
  const T* as() const {
#if defined(NDEBUG) && !defined(NVFUSER_EXPLICIT_ERROR_CHECK)
    return static_cast<const T*>(this);
#else
    const T* const downcast_ptr = dynamic_cast<const T*>(this);
    NVF_ERROR(downcast_ptr != nullptr);
    return downcast_ptr;
#endif // defined(NDEBUG) && !defined(NVFUSER_EXPLICIT_ERROR_CHECK)
  }

  //! True when the runtime type is T or a class derived from T.
  //!
  //! \note Don't use this for conditional casts. Prefer
  //!
  //!   if (auto t = dynamic_cast<T>(p)) { ... }
  //!
  //! over
  //!
  //!   if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //!
  template <class T>
  bool isA() const {
    return nullptr != dynamic_cast<const T*>(this);
  }

  //! True only when the runtime type is exactly T; classes derived from T do
  //! not match.
  template <class T>
  bool isStrictlyA() const {
    return typeid(T) == typeid(*this);
  }

  //! True when isA<U>() holds for at least one U in T... . The checks run
  //! left to right and stop at the first match; an empty list yields false.
  template <class... T>
  bool isOneOf() const {
    return (isA<T>() || ...);
  }

  //! True when isStrictlyA<U>() holds for at least one U in T... . Derived
  //! types that are not listed do not count; an empty list yields false.
  template <class... T>
  bool isStrictlyOneOf() const {
    return (isStrictlyA<T>() || ...);
  }
};
190
+
191
//! Packs two values of the same enum type into one unsigned int: t1 is
//! shifted into the upper 16 bits and t2 is added to the result. Presumably
//! used to build a combined `switch` case label; assumes each enumerator
//! fits in 16 bits — TODO confirm for all enums used with it.
template <class T, std::enable_if_t<std::is_enum<T>::value, bool> = true>
constexpr unsigned int switch_pair(T t1, T t2) {
  constexpr unsigned int kWordShift = 16;
  const unsigned int high = static_cast<unsigned int>(t1) << kWordShift;
  const unsigned int low = static_cast<unsigned int>(t2);
  return high + low;
}
196
+
197
+ std::vector<int64_t> getTensorSizes(at::TensorTypePtr const& tensor_type);
198
+
199
//! Returns the keys of an unordered map sorted by `cmp`, so the map can be
//! iterated in a deterministic order.
//!
//! Keys are copied into the result with push_back, so KeyType only needs to
//! be copy-constructible (it no longer has to be default-constructible). Maps
//! with a custom hasher, key-equality predicate, or allocator are accepted;
//! those extra map parameters are deduced through `MapTraits`.
template <
    typename KeyType,
    typename ValueType,
    typename Cmp,
    typename... MapTraits>
std::vector<KeyType> getSortedKeys(
    const std::unordered_map<KeyType, ValueType, MapTraits...>& map,
    Cmp cmp) {
  std::vector<KeyType> keys;
  keys.reserve(map.size());
  for (const auto& kv : map) {
    keys.push_back(kv.first);
  }
  std::sort(keys.begin(), keys.end(), cmp);
  return keys;
}
214
+
215
// Detects whether T (or, when T is a pointer, the pointee) has a callable
// `toString()` member. Based on https://stackoverflow.com/a/9154394
template <typename T>
static auto hasToStringHelper(int) -> decltype(
    std::declval<std::remove_pointer_t<T>>().toString(),
    std::true_type{});

template <typename>
static auto hasToStringHelper(long) -> std::false_type;

// For the literal 0 the `int` overload is the better match, but SFINAE drops
// it when toString() is unavailable, leaving the `long` fallback.
template <class T>
struct hasToString : decltype(hasToStringHelper<T>(0)) {};

// Generic conversion to a string:
//  - if T (or *T, for pointer types) has toString(), its result is returned;
//  - otherwise the placeholder "<attr>" is returned.
// Stream-printable types are not detected here; they get operator<<-based
// formatting only through explicit Printer specializations.
template <typename T>
struct Printer {
  static std::string toString(const T& value) {
    if constexpr (!hasToString<T>::value) {
      return "<attr>";
    } else if constexpr (std::is_pointer_v<T>) {
      return value->toString();
    } else {
      return value.toString();
    }
  }
};
244
+
245
+ #if 0
246
+
247
+ // Waiting for C++20....
248
+
249
+ #include <concepts>
250
+
251
+ template<typename T>
252
+ concept Printable = requires(T a)
253
+ {
254
+ { std::stringstream{} << a } -> std::convertible_to<std::stringstream>;
255
+ };
256
+
257
+ template <Printable T>
258
+ struct Printer<T> {
259
+ static std::string toString(const T& value) {
260
+ std::stringstream ss;
261
+ ss << value;
262
+ return ss.str();
263
+ }
264
+ };
265
+
266
+ #else
267
+
268
+ #define SPECIALIZE_PRINTER(T) \
269
+ template <> \
270
+ struct Printer<T> { \
271
+ static std::string toString(const T& value) { \
272
+ std::stringstream ss; \
273
+ ss << value; \
274
+ return ss.str(); \
275
+ } \
276
+ }
277
+
278
+ SPECIALIZE_PRINTER(bool);
279
+ SPECIALIZE_PRINTER(int);
280
+ SPECIALIZE_PRINTER(std::string);
281
+ using ConstCharStar = const char*;
282
+ SPECIALIZE_PRINTER(ConstCharStar);
283
+ using VoidStar = void*;
284
+ SPECIALIZE_PRINTER(VoidStar);
285
+ SPECIALIZE_PRINTER(uint32_t);
286
+ SPECIALIZE_PRINTER(int64_t);
287
+ SPECIALIZE_PRINTER(uint64_t);
288
+ SPECIALIZE_PRINTER(DataType);
289
+ SPECIALIZE_PRINTER(MemoryType);
290
+ SPECIALIZE_PRINTER(UnaryOpType);
291
+ SPECIALIZE_PRINTER(BinaryOpType);
292
+ SPECIALIZE_PRINTER(TernaryOpType);
293
+ SPECIALIZE_PRINTER(LoadStoreOpType);
294
+ SPECIALIZE_PRINTER(CircularBufferLoopStage);
295
+ SPECIALIZE_PRINTER(tma::TensorMapInterleave);
296
+ SPECIALIZE_PRINTER(tma::TensorMapL2Promotion);
297
+ SPECIALIZE_PRINTER(tma::TensorMapFloatOOBFill);
298
+ SPECIALIZE_PRINTER(MmaInputSmemSwizzle);
299
+ SPECIALIZE_PRINTER(SwizzleType);
300
+ SPECIALIZE_PRINTER(Swizzle2DType);
301
+ SPECIALIZE_PRINTER(SwizzleMode);
302
+ SPECIALIZE_PRINTER(std::vector<int>);
303
+ SPECIALIZE_PRINTER(std::vector<uint32_t>);
304
+ SPECIALIZE_PRINTER(std::vector<int64_t>);
305
+ SPECIALIZE_PRINTER(std::vector<uint64_t>);
306
+ SPECIALIZE_PRINTER(std::optional<bool>);
307
+
308
+ #undef SPECIALIZE_PRINTER
309
+
310
+ #endif // if 0
311
+
312
+ // Stringification with delimiter
313
+ template <typename Iterator>
314
+ std::string toDelimitedString(
315
+ Iterator first,
316
+ Iterator last,
317
+ std::string delim = ", ") {
318
+ std::stringstream ss;
319
+ bool first_val = true;
320
+ for (auto it = first; it != last; ++it) {
321
+ if (!first_val) {
322
+ ss << delim;
323
+ }
324
+ ss << Printer<typename Iterator::value_type>::toString(*it);
325
+ first_val = false;
326
+ }
327
+ return ss.str();
328
+ }
329
+
330
+ template <typename Printable>
331
+ std::string toDelimitedString(
332
+ const std::vector<Printable>& vec,
333
+ std::string delim = ", ") {
334
+ return toDelimitedString(vec.begin(), vec.end(), delim);
335
+ }
336
+
337
+ template <typename Printable>
338
+ std::string toDelimitedString(
339
+ std::initializer_list<Printable> list,
340
+ std::string delim = ", ") {
341
+ // toDelimitedString(list.begin(), list.end(), delim) doesn't work out of the
342
+ // box, because list.begin() returns a Printable* not an iterator.
343
+ return toDelimitedString(std::vector<Printable>(list), delim);
344
+ }
345
+
346
+ template <typename Printable>
347
+ std::string toDelimitedString(
348
+ const std::deque<Printable>& dq,
349
+ std::string delim = ", ") {
350
+ return toDelimitedString(dq.begin(), dq.end(), delim);
351
+ }
352
+
353
+ template <typename Printable>
354
+ std::string toDelimitedString(
355
+ const std::unordered_set<Printable>& set,
356
+ std::string delim = ", ") {
357
+ return toDelimitedString(set.begin(), set.end(), delim);
358
+ }
359
+
360
//! Compile-time unrolled loop. Invokes
//! `fun(std::integral_constant<int64_t, i>())` for i = index, index + step,
//! index + 2 * step, ... while i < stop, so `fun` receives each loop index as
//! a constant expression (`decltype(i)::value`).
template <int64_t index, int64_t stop, int64_t step, typename func_t>
void unrolled_for(func_t fun) {
  if constexpr (index < stop) {
    static_assert(step > 0, "unrolled_for requires a positive step");
    fun(std::integral_constant<int64_t, index>());
    // `step` must be forwarded explicitly: a two-bound call would resolve to
    // the overload below and silently continue with a step of 1.
    unrolled_for<index + step, stop, step>(fun);
  }
}

//! unrolled_for over [index, stop) with a step of 1.
template <int64_t index, int64_t stop, typename func_t>
void unrolled_for(func_t fun) {
  unrolled_for<index, stop, 1>(fun);
}

//! unrolled_for over [0, stop) with a step of 1.
template <int64_t stop, typename func_t>
void unrolled_for(func_t fun) {
  unrolled_for<0, stop>(fun);
}
377
+
378
+ template <typename... Args>
379
+ std::string toDelimitedString(
380
+ const std::tuple<Args...>& args,
381
+ std::string delim = ", ") {
382
+ std::stringstream ss;
383
+ bool first_val = true;
384
+ unrolled_for<sizeof...(Args)>([&](auto i) {
385
+ if (!first_val) {
386
+ ss << delim;
387
+ }
388
+ auto item = std::get<decltype(i)::value>(args);
389
+ ss << Printer<decltype(item)>::toString(item);
390
+ first_val = false;
391
+ });
392
+ return ss.str();
393
+ }
394
+
395
//! Joins `element->toInlineString()` for every element of `container`,
//! separated by `delim`. Elements must be pointer-like (they are accessed
//! through `->`); presumably Statement pointers, as the template name
//! suggests.
template <typename ContainerOfStatement>
std::string toDelimitedInlineString(
    const ContainerOfStatement& container,
    std::string delim = ", ") {
  std::stringstream out;
  std::string separator; // stays empty until the first element is written
  for (const auto& element : container) {
    out << separator << element->toInlineString();
    separator = delim;
  }
  return out.str();
}
410
+
411
+ class DebugPrintScope {
412
+ public:
413
+ template <typename... Args>
414
+ DebugPrintScope(std::string name, Args... args) : name_(std::move(name)) {
415
+ debug() << "Entering " << name_ << "("
416
+ << toDelimitedString(std::forward_as_tuple(args...)) << ")"
417
+ << std::endl;
418
+ }
419
+
420
+ ~DebugPrintScope() {
421
+ debug() << "Leaving " << name_;
422
+ if (!return_.empty()) {
423
+ debug() << " returning " << return_;
424
+ }
425
+ if (!file_.empty()) {
426
+ debug() << " at " << file_;
427
+ }
428
+ if (line_ >= 0) {
429
+ debug() << ":" << line_;
430
+ }
431
+ debug() << std::endl;
432
+ }
433
+
434
+ template <typename T>
435
+ void setReturn(const T& ret, std::string file = "", int64_t line = -1) {
436
+ return_ = Printer<std::decay_t<T>>::toString(ret);
437
+ file_ = std::move(file);
438
+ line_ = line;
439
+ }
440
+
441
+ private:
442
+ // The name of the scope, as specified as the first argument of
443
+ // DEBUG_PRINT_SCOPE_NAME. If using DEBUG_PRINT_SCOPE, then this is __func__.
444
+ std::string name_;
445
+
446
+ // Return value and location of the return statement.
447
+ // Note that the recording of the return value is not automatic. The function
448
+ // needs to be manually instrumented to replace `return XXX;` with
449
+ // `RECORD_AND_RETURN(XXX)` to record the return value.
450
+ std::string return_;
451
+ std::string file_;
452
+ int64_t line_ = -1;
453
+ };
454
+
455
// Debug printing the entering and leaving of a function. The given arguments
// will be printed when entering the function.
//
// When DebugDumpOption::FunctionTrace is enabled, `name` is checked with
// std::regex_match (which must match the whole name) against each argument of
// that option. On the first match a DebugPrintScope is created and stored in
// the local `_debug_print_scope`, which DEBUG_LOG and RECORD_AND_RETURN test.
//
// The macro's own locals use a `_debug_print_` prefix. This keeps them from
// shadowing caller variables passed through the variadic arguments (e.g. a
// caller variable named `pattern` or `enabled`).
//
// Note: ##__VA_ARGS__ is not C++ standard, but it should work on gcc and clang.
// Compared to __VA_ARGS__, ##__VA_ARGS__ automatically removes the preceding
// comma when empty, allowing empty variadic parameters. If using another
// compiler, please use DebugPrintScope directly without this macro.
#define DEBUG_PRINT_SCOPE_NAME(name, ...)                            \
  std::unique_ptr<DebugPrintScope> _debug_print_scope;               \
  if (isDebugDumpEnabled(DebugDumpOption::FunctionTrace)) {          \
    const auto& _debug_print_patterns =                              \
        getDebugDumpArguments(DebugDumpOption::FunctionTrace);       \
    for (const auto& _debug_print_pattern : _debug_print_patterns) { \
      std::regex _debug_print_re(_debug_print_pattern);              \
      if (std::regex_match(name, _debug_print_re)) {                 \
        _debug_print_scope =                                         \
            std::make_unique<DebugPrintScope>(name, ##__VA_ARGS__);  \
        break;                                                       \
      }                                                              \
    }                                                                \
  }
475
+
476
// Same as DEBUG_PRINT_SCOPE_NAME, using the enclosing function's __func__ as
// the scope name. Any arguments are forwarded to be printed on entry.
#define DEBUG_PRINT_SCOPE(...) DEBUG_PRINT_SCOPE_NAME(__func__, ##__VA_ARGS__)
477
+
478
// Writes "[<file>:<line>] <message>" to debug(), where the message is built by
// to_str from the given arguments. It only writes when the enclosing function
// has a non-null `_debug_print_scope`, i.e. it used DEBUG_PRINT_SCOPE /
// DEBUG_PRINT_SCOPE_NAME and the function-trace pattern matched.
// NOTE(review): this expands to a bare `if` statement, so do not use it as the
// body of an unbraced if/else.
#define DEBUG_LOG(...)                                    \
  if (_debug_print_scope) {                               \
    debug() << "[" << __FILE__ << ":" << __LINE__ << "] " \
            << to_str("", ##__VA_ARGS__) << std::endl;    \
  }
483
+
484
// Record the return value and return it.
// When the enclosing function has an active `_debug_print_scope`, `ret` and the
// current __FILE__/__LINE__ are stored via setReturn, so the scope's "Leaving"
// line reports them.
// NOTE: `ret` is expanded twice, so it is evaluated twice while tracing is
// active. Pass an expression without side effects (e.g. a local variable).
// The expansion ends in `return ret` without a semicolon; write
// `RECORD_AND_RETURN(x);`.
#define RECORD_AND_RETURN(ret)                              \
  if (_debug_print_scope) {                                 \
    _debug_print_scope->setReturn(ret, __FILE__, __LINE__); \
  }                                                         \
  return ret
490
+
491
+ // Computes the index type required.
492
+ // Made into a class w/ state to allow reuse with
493
+ // different tensors and without needing to pass an allocated
494
+ // vector of size+stride
495
+ class KernelIndexTypeCompute {
496
+ // Save 1 more bit besides the sign bit to be conservative
497
+ static constexpr int64_t most_positive_int32_index =
498
+ std::numeric_limits<int>::max() / 2;
499
+
500
+ public:
501
+ // Updates counters and returns current reqd mode
502
+ inline PrimDataType addDim(int64_t size, int64_t stride) {
503
+ if (size > 1) {
504
+ NVF_ERROR(stride >= 0, "Negative stride is not supported: ", stride);
505
+ if (stride > 0) {
506
+ // Accumulate positive stride
507
+ tensor_most_positive_index_ += (size - 1) * stride;
508
+ }
509
+ }
510
+ return getType();
511
+ }
512
+
513
+ inline PrimDataType getType() const {
514
+ if (tensor_most_positive_index_ > most_positive_int32_index) {
515
+ return PrimDataType::Int;
516
+ } else {
517
+ return PrimDataType::Int32;
518
+ }
519
+ }
520
+
521
+ private:
522
+ int64_t tensor_most_positive_index_ = 0;
523
+ };
524
+
525
+ template <typename>
526
+ struct is_std_vector : std::false_type {};
527
+
528
+ template <typename T, typename A>
529
+ struct is_std_vector<std::vector<T, A>> : std::true_type {};
530
+
531
+ template <typename T>
532
+ constexpr auto is_std_vector_v = is_std_vector<T>::value;
533
+
534
//! Mixes `new_hash` into the running `hash`, in place, in the style of
//! boost::hash_combine (see https://stackoverflow.com/q/35985960).
//! The result depends on the order in which hashes are combined. Shifted
//! copies of the current `hash` feed its existing bits back into the mix
//! together with boost's additive constant 0x9e3779b9. All arithmetic is
//! unsigned and wraps.
inline void hashCombine(size_t& hash, size_t new_hash) {
  const size_t mixed = new_hash + 0x9e3779b9 + (hash << 6) + (hash >> 2);
  hash = hash ^ mixed;
}
540
+
541
//! A wrapper to std::getenv. env_name is prepended with NVFUSER_, so for
//! example getNvFuserEnv("DUMP") looks up the variable NVFUSER_DUMP.
//! NOTE(review): presumably returns nullptr when the variable is unset, as
//! std::getenv does. Confirm against the definition.
NVF_API char* getNvFuserEnv(const char* env_name);
543
+
544
// Looks up `key` in `map`. Returns a reference to the mapped value if present,
// otherwise `default_value`.
//
// Lifetime caveat: when the key is absent and `default_value` is not supplied,
// the returned reference refers to the temporary V() created for the call. Use
// the result within the same full-expression, or copy it (`auto v = ...`).
// Binding it to `const auto&` would dangle.
template <typename K, typename V>
const V& getOrDefault(
    const std::unordered_map<K, V>& map,
    const K& key,
    const V& default_value = V()) {
  if (auto found = map.find(key); found != map.end()) {
    return found->second;
  }
  return default_value;
}
553
+
554
// Per its name, returns the number of shared-memory bytes available on the
// device.
// NOTE(review): the definition is not visible here. Which device is queried,
// and whether the figure is per block or per SM, should be confirmed.
size_t deviceAvailableSharedMemoryBytes();
555
+
556
+ inline int64_t wrapDim(int64_t dim, int64_t ndim) {
557
+ if (dim < 0) {
558
+ dim += ndim;
559
+ }
560
+ NVF_CHECK(
561
+ dim >= 0 && dim < ndim,
562
+ "Tried to access out of boundary index ",
563
+ dim,
564
+ ". total index: ",
565
+ ndim);
566
+ return dim;
567
+ }
568
+
569
// Integer exponentiation by squaring. It is the same as the pow utility
// included in runtime/helpers.cu, and is included here to facilitate matching
// host-side computation.
//
// Negative exponents follow integer semantics:
//   a == 1  -> 1
//   a == -1 -> -1 if b is odd, 1 if b is even
//   else    -> 0 (including a == 0)
// The exponent is tested with `&`, so T must be an integral type.
template <typename T>
T pow(T a, T b) {
  if (b < 0) {
    if (a == 1) {
      return 1;
    }
    if (a == -1) {
      // Test parity on b itself. Negating b would overflow when b is the
      // minimum value of T.
      return (b % static_cast<T>(2) != 0) ? -1 : 1;
    }
    return 0;
  }
  T result = 1;
  while (b) {
    if (b & 1) {
      result *= a;
    }
    b /= 2;
    // Square the base only if another iteration will use it. Squaring after
    // the last exponent bit can overflow (UB for signed T) even though the
    // value would be discarded.
    if (b) {
      a *= a;
    }
  }
  return result;
}
594
+
595
// Returns true iff x is a power of two greater than 1 (2, 4, 8, ...).
// Note that 1 (== 2^0), 0 and all negative values yield false.
constexpr bool isPowOf2(int64_t x) {
  if (x <= 1) {
    return false;
  }
  // A power of two has a single set bit, so clearing the lowest set bit
  // (x & (x - 1)) leaves zero.
  return (x & (x - 1)) == 0;
}
599
+
600
// Holds either a non-owning `T*` or an owning `std::unique_ptr<T>`, as a
// dynamic_type::DynamicType over those two alternatives.
// NOTE(review): `NoContainers` presumably disables DynamicType's recursive
// container alternatives. Confirm in the dynamic_type library.
template <typename T>
using MaybeUniqueOwningPtr = dynamic_type::
    DynamicType<dynamic_type::NoContainers, T*, std::unique_ptr<T>>;
603
+
604
// Checks via NVF_CHECK that every element of `elements` compares equal (with
// operator==) to the first one. The first element is compared against itself
// as well. On failure the message names the mismatching pair and lists all
// elements. An empty list passes trivially.
template <typename T>
void checkAllEqual(std::initializer_list<T> elements) {
  // Keep the iterator rather than dereferencing it here, so an empty list
  // never dereferences begin().
  const auto first = elements.begin();
  for (auto it = elements.begin(); it != elements.end(); ++it) {
    NVF_CHECK(
        *it == *first,
        "Expected all elements to be equal, but found ",
        *it,
        " and ",
        *first,
        " in [",
        toDelimitedString(elements),
        "]");
  }
}
618
+
619
+ } // namespace nvfuser