nvfuser-cu121-torch25 0.2.25.dev20250201__cp310-cp310-manylinux_2_28_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (242) hide show
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
@@ -0,0 +1,92 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
#include <array>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>
13
+
14
+ #include <ATen/ArrayRef.h>
15
+
16
+ #include <expr_evaluator.h>
17
+ #include <fusion.h>
18
+ #include <ir/interface_nodes.h>
19
+ #include <iter_visitor.h>
20
+ #include <runtime/executor_params.h>
21
+ #include <type.h>
22
+
23
namespace nvfuser {

//! Tolerance tables and base tolerances passed to the validation helpers
//! declared in this namespace (getTolerance / get_val_constants).
struct ValidationConstants {
  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  // Each row appears to be {reduction size, tolerance}, with sizes doubling
  // from 4 to 2097152. NOTE(review): presumably getTolerance selects a row
  // using its reduction_size argument -- confirm in the implementation.
  std::array<std::array<double, 2>, 20> sum_tolerances_float = {
      {{4, 1.68222e-06}, {8, 2.23704e-06}, {16, 2.95788e-06},
       {32, 4.4778e-06}, {64, 6.75395e-06}, {128, 8.57934e-06},
       {256, 1.30594e-05}, {512, 2.19122e-05}, {1024, 3.3451e-05},
       {2048, 5.78476e-05}, {4096, 0.000108292}, {8192, 0.00012207},
       {16384, 0.000136882}, {32768, 0.000248561}, {65536, 0.000407594},
       {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
       {1048576, 0.00223107}, {2097152, 0.00343043}}};

  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  // Same {reduction size, tolerance} row layout as sum_tolerances_float;
  // presumably used for half-precision results -- confirm.
  std::array<std::array<double, 2>, 20> sum_tolerances_half = {
      {{4, 0.00390625}, {8, 0.0078125}, {16, 0.0078125},
       {32, 0.0155334}, {64, 0.0156269}, {128, 0.0312042},
       {256, 0.0312548}, {512, 0.0619979}, {1024, 0.0625103},
       {2048, 0.124686}, {4096, 0.12501}, {8192, 0.24945},
       {16384, 0.250049}, {32768, 0.498946}, {65536, 0.500071},
       {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
       {1048576, 2.00032}, {2097152, 3.99073}}};

  // Base absolute/relative tolerances per precision.
  // NOTE(review): -1 looks like an "unset" sentinel meaning "use a built-in
  // default" -- confirm in getTolerance.
  double base_half_abs_tol = -1;
  double base_half_rel_tol = -1;
  double base_float_abs_tol = -1;
  double base_float_rel_tol = -1;
};

// Returns abs and relative values to use for validation.
// The pair is presumably ordered {absolute tolerance, relative tolerance},
// chosen for `dtype` and `reduction_size` from `tolerances` -- confirm.
std::pair<double, double> getTolerance(
    DataType dtype,
    int64_t reduction_size,
    const ValidationConstants& tolerances);

//! Walks a fusion (as an IterVisitor) to work out reduction counts per
//! TensorView. Only usable through the static entry point below, because the
//! constructor is private.
class ReductionSizeMapper : private IterVisitor {
 public:
  //! Runs through the fusion and determines how many reductions were performed
  //! to compute each tensorview.
  //! `expr_eval` is held by reference for the duration of the traversal.
  static std::unordered_map<TensorView*, int64_t> computeReductionSizes(
      Fusion* fusion,
      ExpressionEvaluator& expr_eval);

 private:
  ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval);

  // Reduction count for a single TensorView; presumably evaluates extents via
  // expr_eval_ -- confirm in the implementation.
  int64_t getReductionSize(const TensorView* tv);

  // Per-expression visitor hook invoked during the IterVisitor traversal;
  // presumably fills reduction_map for the expression's outputs.
  void dispatch(Expr* expr) override;

  // Bring IterVisitor's handle() overloads into this class's scope.
  using IterVisitor::handle;

  // Accumulated result: reduction count keyed by TensorView.
  std::unordered_map<TensorView*, int64_t> reduction_map;
  // Non-owning reference; the evaluator must outlive this mapper.
  ExpressionEvaluator& expr_eval_;
};

// Builds an ExpressionEvaluator for `fusion`, presumably with the fusion inputs
// bound to `aten_inputs` and launch parameters bound from
// `launch_constraints` -- confirm in the implementation.
ExpressionEvaluator bindInputsAndLaunchParams(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& aten_inputs,
    const LaunchParams& launch_constraints);

// Returns a list of (abs, rel) tolerance pairs for validating `fusion`.
// NOTE(review): presumably one pair per fusion output -- confirm.
std::vector<std::pair<double, double>> get_val_constants(
    Fusion* fusion,
    const at::ArrayRef<c10::IValue>& aten_inputs,
    const LaunchParams& lparams = LaunchParams(),
    const ValidationConstants& tolerances = ValidationConstants());

} // namespace nvfuser
@@ -0,0 +1,31 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
+ #include <ir/all_nodes.h>
11
+
12
namespace nvfuser {

//! Bookkeeping for one vectorized set operation between a producer and a
//! consumer TensorView. All pointers are non-owning; presumably they point
//! into IR owned by the enclosing fusion/kernel -- confirm.
struct VectorizedSetInfo {
  //! Producer of a vectorized set
  TensorView* producer_tv = nullptr;
  //! Consumer of a vectorized set
  TensorView* consumer_tv = nullptr;
  //! Number of elements to vectorize
  //! (-1 presumably means "not yet determined")
  int64_t word_size = -1;
  //! Vectorized domain
  IterDomain* vectorized_loop_id = nullptr;
  //! Right-most allocation dependent domain of the loop domain for consumer
  IterDomain* vectorized_consumer_alloc_id = nullptr;
  //! Right-most allocation dependent domain of the loop domain for producer
  IterDomain* vectorized_producer_alloc_id = nullptr;
  //! All of the dependent allocation domains that are contiguously merged
  std::unordered_set<IterDomain*> contig_alloc_ids;
};

} // namespace nvfuser
@@ -0,0 +1,21 @@
1
+ // clang-format off
2
+ /*
3
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ */
7
+ // clang-format on
8
+ #pragma once
9
+
10
/// NVF_API marks classes and functions that are used outside of nvFuser and
/// therefore have to be exported from the shared library.
/// See doc/dev/visibility.md for details.
///
/// - Non-Windows targets: give the symbol default visibility, so it stays
///   exported even when the rest of the library is built with hidden
///   visibility.
/// - Windows / Cygwin: dllexport while the DLL itself is being built
///   (BUILDING_DLL defined), dllimport for everyone consuming it.
#if !defined(_WIN32) && !defined(__CYGWIN__)
#define NVF_API __attribute__((visibility("default")))
#elif defined(BUILDING_DLL)
#define NVF_API __declspec(dllexport)
#else
#define NVF_API __declspec(dllimport)
#endif
Binary file
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ from typing import Any
5
+ from .version import _version_str
6
+
7
+ __all__ = ["NvfuserVersion", "Version"]
8
+
9
+
10
class _LazyImport:
    """Proxy for a class from ``packaging.version`` that is imported on first use.

    Calling ``v()`` gives the same result in both of these snippets::

        from packaging.version import Version
        def v():
            return Version('1.2.3')

    and::

        Version = _LazyImport('Version')
        def v():
            return Version('1.2.3')

    The only difference is that the second form does not import anything
    until ``v`` is actually called.
    """

    def __init__(self, cls_name: str) -> None:
        # Attribute name looked up on packaging.version each time it is needed.
        self._cls_name = cls_name

    def get_cls(self):
        """Import packaging.version and return the wrapped class.

        Falls back to the copy of ``packaging`` vendored in ``pkg_resources``
        when the standalone package is not installed.
        """
        try:
            import packaging.version  # type: ignore[import]
        except ImportError:
            # packaging is unavailable; use the vendored copy instead.
            from pkg_resources import packaging  # type: ignore[attr-defined, no-redef]
        version_module = packaging.version
        return getattr(version_module, self._cls_name)

    def __call__(self, *args, **kwargs):
        """Construct an instance of the wrapped class."""
        target_cls = self.get_cls()
        return target_cls(*args, **kwargs)

    def __instancecheck__(self, obj):
        """Let ``isinstance(obj, proxy)`` test against the wrapped class."""
        target_cls = self.get_cls()
        return isinstance(obj, target_cls)


# Lazily resolved stand-in for packaging.version.Version.
Version = _LazyImport("Version")
44
+
45
+
46
class NvfuserVersion(str):
    """String subclass whose rich comparisons compare as versions.

    The object is still a ``str`` (printing and string methods are
    unchanged), but ``==``, ``!=``, ``<``, ``<=``, ``>`` and ``>=`` first convert
    both operands to ``packaging.version.Version``. Anything after ``"+"``
    (e.g. ``"+git93b68e0"``) is dropped before parsing. Comparing against
    something that is neither a str nor a packaging Version raises ValueError.
    """

    @classmethod
    def _convert_to_version(cls, ver: Any) -> Version:
        """Convert ``ver`` (str or packaging Version) to a packaging Version.

        Raises:
            ValueError: if ``ver`` is neither a str nor a packaging Version.
        """
        if isinstance(ver, str):
            # Drop the local version label ("+git...") before parsing.
            return Version(ver.split("+")[0])
        elif isinstance(ver, Version.get_cls()):
            return ver
        else:
            raise ValueError("can't convert {} to Version".format(ver))

    def _cmp_version(self, other: Any, method: str) -> bool:
        """Apply rich-comparison ``method`` (e.g. ``"__lt__"``) to self and other as Versions."""
        return getattr(NvfuserVersion._convert_to_version(self), method)(
            NvfuserVersion._convert_to_version(other)
        )


# Install version-aware rich comparisons. "__ne__" must be part of the list:
# str defines its own __ne__, so without an override "!=" would fall back to a
# plain string comparison and contradict "==" (e.g. "0.2.25+git93b68e0" ==
# "0.2.25" is True as versions, but != would also be True as strings).
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(
        NvfuserVersion,
        cmp_method,
        lambda x, y, method=cmp_method: x._cmp_version(y, method),
    )

__version__ = NvfuserVersion(_version_str)
@@ -0,0 +1,184 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ import torch
5
+
6
+ from ._C import DataType
7
+
8
+ import ctypes
9
+ import functools
10
+ import gc
11
+ from typing import Type, Union, Tuple
12
+
13
+ NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
14
+
15
# Lookup table from torch dtypes and Python scalar types to nvFuser's DataType.
# Built from an ordered list of (key, value) pairs, so insertion order is
# exactly the order listed here.
_torch_dtype_to_nvfuser_dtype_map = dict(
    [
        # Complex tensor dtypes
        (torch.cdouble, DataType.ComplexDouble),
        (torch.cfloat, DataType.ComplexFloat),
        # Floating-point tensor dtypes
        (torch.double, DataType.Double),
        (torch.float, DataType.Float),
        (torch.half, DataType.Half),
        (torch.bfloat16, DataType.BFloat16),
        (torch.float8_e4m3fn, DataType.Float8_e4m3fn),
        (torch.float8_e5m2, DataType.Float8_e5m2),
        # Integer and boolean tensor dtypes
        (torch.long, DataType.Int),
        (torch.int, DataType.Int32),
        (torch.bool, DataType.Bool),
        # Python scalars
        (complex, DataType.ComplexDouble),
        (float, DataType.Double),
        (int, DataType.Int),
        (bool, DataType.Bool),
    ]
)
33
+
34
+
35
def python_scalar_to_nvfuser_dtype(a: Union[int, float, complex, bool]):
    """Return nvFuser's DataType for a Python scalar value.

    The lookup is keyed on the exact ``type(a)``, so a value whose type has no
    entry in the table raises KeyError.
    """
    scalar_type = type(a)
    return _torch_dtype_to_nvfuser_dtype_map[scalar_type]
37
+
38
+
39
def torch_dtype_to_nvfuser_dtype(dtype: Union[torch.dtype, NumberTypeType]):
    """
    Translate a torch.dtype (or a Python scalar type such as ``float``) to
    nvFuser's DataType enum.

    Raises KeyError when ``dtype`` has no entry in the translation table.
    """
    lookup = _torch_dtype_to_nvfuser_dtype_map
    return lookup[dtype]
44
+
45
+
46
def get_device_properties() -> dict:
    """
    Computes device properties using ctypes and cuda.

    Static properties are read from ``torch.cuda.get_device_properties``. The
    rest are queried with ``cuDeviceGetAttribute`` from the CUDA driver
    library loaded through ctypes.
    Note: Consider using CUDA-Python when CUDA support >= 12.0.

    Returns:
        A dict mapping property names (``"gpu_name"``, ``"gpu_sm_count"``,
        ``"gpu_peak_bandwidth_gbps"``, ...) to their values for the current
        torch CUDA device.

    Raises:
        OSError: if none of the candidate CUDA driver libraries can be loaded.
        RuntimeError: if a ``cuDeviceGetAttribute`` call reports an error.
    """
    # "libcuda.so" is usually only a development symlink; the driver itself
    # installs the versioned "libcuda.so.1", so try that before giving up.
    libnames = (
        "libcuda.so",
        "libcuda.so.1",
        "libcuda.dylib",
        "nvcuda.dll",
        "cuda.dll",
    )
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
        except OSError:
            continue
        else:
            break
    else:
        raise OSError("could not load any of: " + " ".join(libnames))

    # Device attribute enums (taken from cuda.h)
    # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1ge12b8a782bebe21b1ac0091bf9f4e2a3

    CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1
    CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8
    CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12
    CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13
    CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36
    CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37
    CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38
    CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39

    device = torch.cuda.current_device()
    cuda_properties = torch.cuda.get_device_properties(device)

    def query_attribute(attribute: int) -> int:
        # Read one integer device attribute. cuDeviceGetAttribute returns a
        # CUresult where 0 means success; failing loudly avoids silently
        # reporting 0 for the property.
        value = ctypes.c_int()
        status = cuda.cuDeviceGetAttribute(ctypes.byref(value), attribute, device)
        if status != 0:
            raise RuntimeError(
                f"cuDeviceGetAttribute(attribute={attribute}) failed "
                f"with CUresult {status}"
            )
        return value.value

    device_properties = {}
    device_properties["gpu_name"] = cuda_properties.name
    device_properties["gpu_compute_capability_major"] = cuda_properties.major
    device_properties["gpu_compute_capability_minor"] = cuda_properties.minor
    device_properties["gpu_gmem_bytes"] = cuda_properties.total_memory
    device_properties["gpu_sm_count"] = cuda_properties.multi_processor_count

    device_properties["gpu_max_threads_per_block"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK
    )
    device_properties["gpu_smem_bytes_per_block"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK
    )
    device_properties["gpu_regs_per_block"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK
    )
    device_properties["gpu_clock_rate_khz"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_CLOCK_RATE
    )
    device_properties["gpu_l2_bytes"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE
    )

    memory_clock_rate_khz = query_attribute(CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE)
    device_properties["gpu_mem_clock_khz"] = memory_clock_rate_khz

    memory_bus_width_bits = query_attribute(
        CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH
    )
    device_properties["gpu_mem_bus_width"] = memory_bus_width_bits

    device_properties["gpu_max_threads_per_sm"] = query_attribute(
        CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR
    )

    # Compute peak bandwidth in GBps: bus width in bytes (/8) times memory clock
    # in Hz (kHz * 1e3), scaled to GB (/1e9). The factor of 2 presumably
    # accounts for double-data-rate memory.
    peak_bandwidth = (2 * memory_bus_width_bits * memory_clock_rate_khz) / (1e6 * 8)
    device_properties["gpu_peak_bandwidth_gbps"] = peak_bandwidth

    return device_properties
149
+
150
+
151
# Properties of the current CUDA device, or None when CUDA is unavailable.
# Querying them loads the CUDA driver library, which would raise on machines
# without CUDA, so the query is skipped there.
DEVICE_PROPERTIES = get_device_properties() if torch.cuda.is_available() else None
155
+
156
+
157
+ def retry_on_oom_or_skip_test(func):
158
+ """Decorator: upon torch.OutOfMemoryError clear the cache and retry test"""
159
+
160
+ @functools.wraps(func)
161
+ def retried_func(*args, **kwargs):
162
+ try:
163
+ output = func(*args, **kwargs)
164
+ except torch.OutOfMemoryError:
165
+ pass
166
+ else:
167
+ return output
168
+
169
+ # We have hit an OOM error, so clear the cache and retry
170
+ gc.collect()
171
+ torch.cuda.empty_cache()
172
+
173
+ try:
174
+ output = func(*args, **kwargs)
175
+ except torch.OutOfMemoryError as e:
176
+ # If we hit an OOM this time, then skip the test
177
+ import pytest
178
+
179
+ pytest.skip(f"Test failed due to OutOfMemoryError: {e}")
180
+ return
181
+
182
+ return output
183
+
184
+ return retried_func
@@ -0,0 +1,20 @@
1
#----------------------------------------------------------------
# Generated CMake target import file for configuration "Release".
#----------------------------------------------------------------
# NOTE(review): generated by CMake -- change nvFuser's install/export rules
# rather than editing this file. It is include()d by NvfuserConfig.cmake
# (which globs NvfuserConfig-*.cmake). That file defines _IMPORT_PREFIX
# (the installed package root) and consumes the _cmake_import_check_* lists
# appended below.

# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Import target "nvfuser_codegen" for configuration "Release"
set_property(TARGET nvfuser_codegen APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
# "torch" is recorded as a shared library that libnvfuser_codegen.so depends on.
# NOTE(review): the config file never calls find_dependency(Torch); consumers
# presumably must find Torch themselves -- confirm.
set_target_properties(nvfuser_codegen PROPERTIES
  IMPORTED_LINK_DEPENDENT_LIBRARIES_RELEASE "torch"
  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libnvfuser_codegen.so"
  IMPORTED_SONAME_RELEASE "libnvfuser_codegen.so"
  )

# Register the files the including config must verify exist on disk.
list(APPEND _cmake_import_check_targets nvfuser_codegen )
list(APPEND _cmake_import_check_files_for_nvfuser_codegen "${_IMPORT_PREFIX}/lib/libnvfuser_codegen.so" )

# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
@@ -0,0 +1,106 @@
1
# Generated by CMake
#
# Package config for the imported target nvfuser_codegen.
# NOTE(review): this is CMake-generated export code -- change nvFuser's
# install/export rules rather than editing it by hand.

# Refuse to run on CMake versions older than 2.8.3.
if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
   message(FATAL_ERROR "CMake >= 2.8.3 required")
endif()
if(CMAKE_VERSION VERSION_LESS "2.8.3")
   message(FATAL_ERROR "CMake >= 2.8.3 required")
endif()
# Scope policy settings to this file; every exit path below pops them again.
cmake_policy(PUSH)
cmake_policy(VERSION 2.8.3...3.29)
#----------------------------------------------------------------
# Generated CMake target import file.
#----------------------------------------------------------------

# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Protect against multiple inclusion, which would fail when already imported targets are added once more.
# If every expected target already exists, this file was already loaded:
# return quietly. If only some exist, the export set is inconsistent: error out.
set(_cmake_targets_defined "")
set(_cmake_targets_not_defined "")
set(_cmake_expected_targets "")
foreach(_cmake_expected_target IN ITEMS nvfuser_codegen)
  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
  if(TARGET "${_cmake_expected_target}")
    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
  else()
    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
  endif()
endforeach()
unset(_cmake_expected_target)
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
  unset(_cmake_targets_defined)
  unset(_cmake_targets_not_defined)
  unset(_cmake_expected_targets)
  unset(CMAKE_IMPORT_FILE_VERSION)
  cmake_policy(POP)
  return()
endif()
if(NOT _cmake_targets_defined STREQUAL "")
  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)


# Compute the installation prefix relative to this file.
# Four PATH steps strip the file name plus "nvfuser", "cmake" and "share"
# (this file lives in <prefix>/share/cmake/nvfuser/), leaving <prefix>.
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
if(_IMPORT_PREFIX STREQUAL "/")
  set(_IMPORT_PREFIX "")
endif()

# Create imported target nvfuser_codegen
add_library(nvfuser_codegen SHARED IMPORTED)

# Consumers of nvfuser_codegen get <prefix>/include/nvfuser on their include path.
set_target_properties(nvfuser_codegen PROPERTIES
  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/nvfuser"
)

# Load information for each installed configuration.
# Each NvfuserConfig-<config>.cmake sets the per-configuration library location
# and appends to the _cmake_import_check_* lists used below.
file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/NvfuserConfig-*.cmake")
foreach(_cmake_config_file IN LISTS _cmake_config_files)
  include("${_cmake_config_file}")
endforeach()
unset(_cmake_config_file)
unset(_cmake_config_files)

# Cleanup temporary variables.
set(_IMPORT_PREFIX)

# Loop over all imported files and verify that they actually exist
foreach(_cmake_target IN LISTS _cmake_import_check_targets)
  if(CMAKE_VERSION VERSION_LESS "3.28"
      OR NOT DEFINED _cmake_import_check_xcframework_for_${_cmake_target}
      OR NOT IS_DIRECTORY "${_cmake_import_check_xcframework_for_${_cmake_target}}")
    foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
      if(NOT EXISTS "${_cmake_file}")
        message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
\"${_cmake_file}\"
but this file does not exist. Possible reasons include:
* The file was deleted, renamed, or moved to another location.
* An install or uninstall procedure did not complete successfully.
* The installation package was faulty and contained
\"${CMAKE_CURRENT_LIST_FILE}\"
but not all the files it references.
")
      endif()
    endforeach()
  endif()
  unset(_cmake_file)
  unset("_cmake_import_check_files_for_${_cmake_target}")
endforeach()
unset(_cmake_target)
unset(_cmake_import_check_targets)

# This file does not depend on other imported targets which have
# been exported from the same project but in a separate export set.
# NOTE(review): no find_dependency() calls are made here, although the Release
# import file lists "torch" as a link-dependent library; consumers presumably
# need to find Torch themselves -- confirm.

# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
cmake_policy(POP)
nvfuser/utils.py ADDED
@@ -0,0 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
2
+ # All rights reserved.
3
+ # SPDX-License-Identifier: BSD-3-Clause
4
+ import os
5
+
6
+
7
__all__ = ["cmake_prefix_path"]

# Directory that holds the bundled CMake package config (NvfuserConfig.cmake).
# It is built by stepping up from this file to the directory containing the
# ``nvfuser`` package, then back down into nvfuser/share/cmake/nvfuser.
cmake_prefix_path = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    *("nvfuser", "share", "cmake", "nvfuser"),
)
nvfuser/version.py ADDED
@@ -0,0 +1 @@
1
# Version string recorded at build time; the "+git..." local suffix presumably
# identifies the source commit.
_version_str = "0.2.25+git93b68e0"