nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl

Files changed (242)
  1. nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
  2. nvfuser/__init__.py +618 -0
  3. nvfuser/__init__.pyi +4 -0
  4. nvfuser/contrib/__init__.py +9 -0
  5. nvfuser/contrib/nn/__init__.py +13 -0
  6. nvfuser/contrib/nn/normalization.py +725 -0
  7. nvfuser/include/nvfuser/alias_analysis.h +116 -0
  8. nvfuser/include/nvfuser/bfs.h +929 -0
  9. nvfuser/include/nvfuser/codegen.h +26 -0
  10. nvfuser/include/nvfuser/compute_at.h +28 -0
  11. nvfuser/include/nvfuser/compute_at_map.h +394 -0
  12. nvfuser/include/nvfuser/contiguity.h +351 -0
  13. nvfuser/include/nvfuser/cuda_utils.h +50 -0
  14. nvfuser/include/nvfuser/debug.h +50 -0
  15. nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
  16. nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
  17. nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
  18. nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
  19. nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
  20. nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
  21. nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
  22. nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
  23. nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
  24. nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
  25. nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
  26. nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
  27. nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
  28. nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
  29. nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
  30. nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
  31. nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
  32. nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
  33. nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
  34. nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
  35. nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
  36. nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
  37. nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
  38. nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
  39. nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
  40. nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
  41. nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
  42. nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
  43. nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
  44. nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
  45. nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
  46. nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
  47. nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
  48. nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
  49. nvfuser/include/nvfuser/device_lower/utils.h +382 -0
  50. nvfuser/include/nvfuser/device_lower/validation.h +74 -0
  51. nvfuser/include/nvfuser/disjoint_set.h +556 -0
  52. nvfuser/include/nvfuser/dispatch.h +334 -0
  53. nvfuser/include/nvfuser/driver_api.h +49 -0
  54. nvfuser/include/nvfuser/dynamic_transform.h +316 -0
  55. nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
  56. nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
  57. nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
  58. nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
  59. nvfuser/include/nvfuser/evaluator_common.h +295 -0
  60. nvfuser/include/nvfuser/exceptions.h +283 -0
  61. nvfuser/include/nvfuser/expr_evaluator.h +125 -0
  62. nvfuser/include/nvfuser/expr_simplifier.h +218 -0
  63. nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
  64. nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
  65. nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
  66. nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
  67. nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
  68. nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
  69. nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
  70. nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
  71. nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
  72. nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
  73. nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
  74. nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
  75. nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
  76. nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
  77. nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
  78. nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
  79. nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
  80. nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
  81. nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
  82. nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
  83. nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
  84. nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
  85. nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
  86. nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
  87. nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
  88. nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
  89. nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
  90. nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
  91. nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
  92. nvfuser/include/nvfuser/fusion.h +511 -0
  93. nvfuser/include/nvfuser/fusion_guard.h +37 -0
  94. nvfuser/include/nvfuser/fusion_profiler.h +311 -0
  95. nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
  96. nvfuser/include/nvfuser/global_allocator.h +27 -0
  97. nvfuser/include/nvfuser/grouped_reduction.h +47 -0
  98. nvfuser/include/nvfuser/host_ir/container.h +60 -0
  99. nvfuser/include/nvfuser/host_ir/executor.h +152 -0
  100. nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
  101. nvfuser/include/nvfuser/host_ir/lower.h +35 -0
  102. nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
  103. nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
  104. nvfuser/include/nvfuser/id_model/id_model.h +359 -0
  105. nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
  106. nvfuser/include/nvfuser/id_model/indexing.h +208 -0
  107. nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
  108. nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
  109. nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
  110. nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
  111. nvfuser/include/nvfuser/id_model/schedule.h +54 -0
  112. nvfuser/include/nvfuser/id_model/to_string.h +87 -0
  113. nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
  114. nvfuser/include/nvfuser/id_model/utils.h +176 -0
  115. nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
  116. nvfuser/include/nvfuser/index_compute.h +651 -0
  117. nvfuser/include/nvfuser/instrumentation.h +107 -0
  118. nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
  119. nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
  120. nvfuser/include/nvfuser/ir/builder.h +215 -0
  121. nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
  122. nvfuser/include/nvfuser/ir/cloner.h +185 -0
  123. nvfuser/include/nvfuser/ir/container.h +226 -0
  124. nvfuser/include/nvfuser/ir/graphviz.h +119 -0
  125. nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
  126. nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
  127. nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
  128. nvfuser/include/nvfuser/ir/iostream.h +98 -0
  129. nvfuser/include/nvfuser/ir/printer.h +57 -0
  130. nvfuser/include/nvfuser/ir/utils.h +801 -0
  131. nvfuser/include/nvfuser/iter_visitor.h +661 -0
  132. nvfuser/include/nvfuser/kernel.h +299 -0
  133. nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
  134. nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
  135. nvfuser/include/nvfuser/kernel_ir.h +1457 -0
  136. nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
  137. nvfuser/include/nvfuser/linked_hash_map.h +97 -0
  138. nvfuser/include/nvfuser/logical_domain_map.h +577 -0
  139. nvfuser/include/nvfuser/macros.h +23 -0
  140. nvfuser/include/nvfuser/mma_type.h +257 -0
  141. nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
  142. nvfuser/include/nvfuser/multidevice/communication.h +232 -0
  143. nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
  144. nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
  145. nvfuser/include/nvfuser/multidevice/executor.h +107 -0
  146. nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
  147. nvfuser/include/nvfuser/multidevice/utils.h +187 -0
  148. nvfuser/include/nvfuser/non_divisible_split.h +86 -0
  149. nvfuser/include/nvfuser/opaque_type.h +129 -0
  150. nvfuser/include/nvfuser/ops/alias.h +192 -0
  151. nvfuser/include/nvfuser/ops/all_ops.h +13 -0
  152. nvfuser/include/nvfuser/ops/arith.h +712 -0
  153. nvfuser/include/nvfuser/ops/composite.h +130 -0
  154. nvfuser/include/nvfuser/ops/indexing.h +55 -0
  155. nvfuser/include/nvfuser/ops/normalization.h +263 -0
  156. nvfuser/include/nvfuser/ops/utils.h +127 -0
  157. nvfuser/include/nvfuser/options.h +313 -0
  158. nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
  159. nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
  160. nvfuser/include/nvfuser/polymorphic_value.h +432 -0
  161. nvfuser/include/nvfuser/predicate_compute.h +213 -0
  162. nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
  163. nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
  164. nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
  165. nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
  166. nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
  167. nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
  168. nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
  169. nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
  170. nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
  171. nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
  172. nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
  173. nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
  174. nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
  175. nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
  176. nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
  177. nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
  178. nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
  179. nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
  180. nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
  181. nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
  182. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
  183. nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
  184. nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
  185. nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
  186. nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
  187. nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
  188. nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
  189. nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
  190. nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
  191. nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
  192. nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
  193. nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
  194. nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
  195. nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
  196. nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
  197. nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
  198. nvfuser/include/nvfuser/scheduler/registry.h +97 -0
  199. nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
  200. nvfuser/include/nvfuser/scheduler/resize.h +41 -0
  201. nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
  202. nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
  203. nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
  204. nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
  205. nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
  206. nvfuser/include/nvfuser/scheduler/utils.h +771 -0
  207. nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
  208. nvfuser/include/nvfuser/serde/factory.h +55 -0
  209. nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
  210. nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
  211. nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
  212. nvfuser/include/nvfuser/serde/utils.h +34 -0
  213. nvfuser/include/nvfuser/struct.inl +127 -0
  214. nvfuser/include/nvfuser/swizzle.h +54 -0
  215. nvfuser/include/nvfuser/sys_utils.h +40 -0
  216. nvfuser/include/nvfuser/tensor_metadata.h +118 -0
  217. nvfuser/include/nvfuser/tma.h +124 -0
  218. nvfuser/include/nvfuser/transform_iter.h +522 -0
  219. nvfuser/include/nvfuser/transform_replay.h +297 -0
  220. nvfuser/include/nvfuser/transform_rfactor.h +33 -0
  221. nvfuser/include/nvfuser/transform_view.h +136 -0
  222. nvfuser/include/nvfuser/type.h +1125 -0
  223. nvfuser/include/nvfuser/type_promotion.h +61 -0
  224. nvfuser/include/nvfuser/utils.h +619 -0
  225. nvfuser/include/nvfuser/val_graph.h +446 -0
  226. nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
  227. nvfuser/include/nvfuser/validator_utils.h +92 -0
  228. nvfuser/include/nvfuser/vectorization_info.h +31 -0
  229. nvfuser/include/nvfuser/visibility.h +21 -0
  230. nvfuser/lib/libnvfuser_codegen.so +0 -0
  231. nvfuser/nvfuser_version.py +69 -0
  232. nvfuser/pytorch_utils.py +184 -0
  233. nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
  234. nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
  235. nvfuser/utils.py +18 -0
  236. nvfuser/version.py +1 -0
  237. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
  238. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
  239. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
  240. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
  241. nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
  242. nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/contrib/nn/normalization.py (new file, +725 lines):
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import enum
from typing import Any, Dict, List, Optional, Tuple

import torch

import nvfuser

from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype


__all__ = [
    "InstanceNorm1dNVFuser",
    "InstanceNorm2dNVFuser",
    "InstanceNorm3dNVFuser",
]


NamedAxis = enum.Enum("NamedAxis", ["BATCH", "CHANNEL"])


def partially_contig_tensor(
    fd: "nvfuser.FusionDefinition",
    x: torch.Tensor,
) -> "nvfuser.Tensor":
    return fd.define_tensor(
        shape=[1 if dim_size == 1 else -1 for dim_size in x.size()],
        contiguity=nvfuser.compute_contiguity(x.size(), x.stride()),
        dtype=torch_dtype_to_nvfuser_dtype(x.dtype),
    )
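
# For illustration (an assumed example, not part of the original file): a
# contiguous half-precision tensor of shape (8, 4, 32, 32) yields the
# define_tensor arguments
#   shape=[-1, -1, -1, -1], contiguity=[True, True, True, True],
#   dtype=DataType.Half
# while any size-1 dimension would instead be pinned to 1 in `shape`.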


def norm_fusion_forward(
    fd: "nvfuser.FusionDefinition",
    inputs: List[torch.Tensor],
    x: "nvfuser.Tensor",
    weight: Optional["nvfuser.Tensor"],
    bias: Optional["nvfuser.Tensor"],
    running_mean: Optional["nvfuser.Tensor"],
    running_var: Optional["nvfuser.Tensor"],
    eps: "nvfuser.Scalar",
    use_input_stats: bool,
    momentum: "nvfuser.Scalar",
    channels_last: bool,
    x_datatype: "nvfuser.DataType",
    unbiased: bool = False,
    *,
    stat_axes: List[NamedAxis],
) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]:
    """Modify FusionDefinition to add a generic normalization layer (forward).

    This can be used to construct a BatchNorm, GroupNorm, InstanceNorm, or
    LayerNorm network by indicating different sets of axes to preserve.

    BatchNorm: `stat_axes = [NamedAxis.CHANNEL]`
    LayerNorm: `stat_axes = [NamedAxis.BATCH]`
    InstanceNorm: `stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL]`

    Args:
        fd: An initialized FusionDefinition.
        inputs: A list of :class:`torch.Tensor` inputs to the
            `FusionDefinition` `fd`.
        x: An input NVFuser tensor.
        weight: If given, multiply the normed output by this `Tensor`. It
            should be one-dimensional if `NamedAxis.CHANNEL` is in
            `stat_axes`, and zero-dimensional otherwise. It will be broadcast
            along all other dimensions.
        bias: If given, add this `Tensor` to the normed output. It should be
            one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and
            zero-dimensional otherwise. It will be broadcast along all other
            dimensions.
        running_mean: If given, a running mean estimate that will be modified
            in place.
        running_var: If given, a running variance estimate that will be
            modified in place.
        eps: Amount to regularize the square root needed to convert variance
            to standard deviation.
        use_input_stats: Whether to compute the stats of this batch or to
            _only_ use the provided running_mean and running_var.
        momentum: Momentum for the exponentially weighted moving average of
            running stats.
        channels_last: Whether channels are in position -1 (`True`) or 1
            (`False`).
        x_datatype: :class:`DataType` of input :class:`Tensor` `x`.
        unbiased: Whether to use unbiased variance for computing current batch
            statistics. Note that unbiased estimates are always used for
            running variance updates, regardless of this argument's value.
        stat_axes: A list of `NamedAxis` objects indicating the combination of
            axes with which to index the computed statistics. This can be used
            to implement multiple types of normalization layers, since most of
            those differ only in which axes are reduced over.

    Returns:
        The normalized output, as well as the mean and 1/std. Note that
        `fd.add_output` is _not_ called by this function.
    """
    assert (running_var is None) == (
        running_mean is None
    ), "If either running_mean or running_var is given, both must be given"

    # dyn_shape holds Scalars describing the size of the input x
    dyn_shape = fd.ops.tensor_sizes(x)

    num_dims = len(dyn_shape)

    batch_dim = 0
    batch_size = dyn_shape[batch_dim]

    channel_dim = num_dims - 1 if channels_last else 1
    num_channels = dyn_shape[channel_dim]

    # Running stats will be kept possibly per channel but never per instance,
    # so we will reduce along batch_dim before updating running stats.
    # These masks are used to broadcast in the spatial dims.
    is_spatial_dim = [True] * num_dims
    is_spatial_or_batch_dim = [True] * num_dims

    if NamedAxis.BATCH in stat_axes:
        is_spatial_dim[batch_dim] = False
    if NamedAxis.CHANNEL in stat_axes:
        is_spatial_dim[channel_dim] = False
        is_spatial_or_batch_dim[channel_dim] = False
    x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag]
    num_features = fd.define_scalar(1)
    for ax in x_reduction_axes:
        num_features = fd.ops.mul(num_features, dyn_shape[ax])
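    # Worked example (assuming a 4-D NCHW input): InstanceNorm has stat_axes =
    # [BATCH, CHANNEL], so is_spatial_dim = [False, False, True, True],
    # x_reduction_axes = [2, 3], and num_features = H * W; BatchNorm
    # (stat_axes = [CHANNEL]) also reduces the batch dim, giving
    # x_reduction_axes = [0, 2, 3] and num_features = N * H * W.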

    if use_input_stats or running_mean is None:
        # The third var_mean argument is the correction term; int(unbiased)
        # passes correction=1 (an unbiased estimate) when unbiased is True
        x_var, x_mean = fd.ops.var_mean(x, x_reduction_axes, int(unbiased))
        if running_mean is not None:
            one = fd.define_scalar(1.0)
            rev_momentum = fd.ops.sub(one, momentum)

            # update the running mean with momentum
            current_mean_hat = fd.ops.mul(x_mean, momentum)
            mean_hat = fd.ops.mul(running_mean, rev_momentum)
            new_mean_hat = fd.ops.add(mean_hat, current_mean_hat)

            # If computing stats for each instance, we don't want to keep
            # those for our running mean calculation, so we sum them here
            new_mean_sum = (
                fd.ops.sum(new_mean_hat, [0])
                if NamedAxis.BATCH in stat_axes
                else new_mean_hat
            )

            rev_batch_size = fd.ops.reciprocal(batch_size)
            new_mean_channels_only = fd.ops.mul(new_mean_sum, rev_batch_size)
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                new_mean_channels_only = fd.ops.cast(
                    new_mean_channels_only, x_datatype
                )
            fd.add_output(new_mean_channels_only, alias_input=running_mean)
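            # In closed form this is the usual EWMA update:
            #   running_mean <- (1 - momentum) * running_mean + momentum * m
            # where m is the batch mean, averaged over the batch dim first
            # when per-instance stats are kept (e.g. InstanceNorm)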

            # running var calculation
            x_var_unbiased = x_var
            if not unbiased:
                # multiply by the correction factor to go from the biased to
                # the unbiased estimate
                b2ub = fd.ops.div(
                    num_features, fd.ops.sub(num_features, fd.define_scalar(1))
                )
                x_var_unbiased = fd.ops.mul(x_var, b2ub)
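                # e.g. with num_features n = 4, the biased variance is scaled
                # by n / (n - 1) = 4/3 (Bessel's correction)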

            current_var_hat = fd.ops.mul(x_var_unbiased, momentum)
            var_hat = fd.ops.mul(running_var, rev_momentum)
            new_var_hat = fd.ops.add(var_hat, current_var_hat)

            # See above about reducing over the batch dim for running stats
            new_var_sum = (
                fd.ops.sum(new_var_hat, [0])
                if NamedAxis.BATCH in stat_axes
                else new_var_hat
            )

            new_var_channels_only = fd.ops.mul(new_var_sum, rev_batch_size)
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                new_var_channels_only = fd.ops.cast(
                    new_var_channels_only, x_datatype
                )
            fd.add_output(new_var_channels_only, alias_input=running_var)

        mean = x_mean
        mean_bcast = fd.ops.broadcast(mean, is_spatial_dim)
        x_sub_mean = fd.ops.sub(x, mean_bcast)

        var_eps = fd.ops.add(x_var, eps)
        invstd = fd.ops.rsqrt(var_eps)
        invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim)

        x_normed = fd.ops.mul(x_sub_mean, invstd_bcast)

    else:  # This is inference mode with running stats
        assert running_mean is not None
        r_mean_bcast = fd.ops.broadcast(running_mean, is_spatial_or_batch_dim)
        x_sub_mean = fd.ops.sub(x, r_mean_bcast)

        var_eps = fd.ops.add(running_var, eps)
        invstd = fd.ops.rsqrt(var_eps)
        invstd_bcast = fd.ops.broadcast(invstd, is_spatial_or_batch_dim)

        mean = running_mean
        x_normed = fd.ops.mul(x_sub_mean, invstd_bcast)

    if weight is not None:
        weight_bcast = fd.ops.broadcast(weight, is_spatial_or_batch_dim)
        x_normed = fd.ops.mul(x_normed, weight_bcast)
    if bias is not None:
        bias_bcast = fd.ops.broadcast(bias, is_spatial_or_batch_dim)
        x_normed = fd.ops.add(x_normed, bias_bcast)

    return x_normed, mean, invstd
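
# A minimal sketch of driving norm_fusion_forward directly (illustrative only;
# NormNVFuserFunction below is the path this file actually wires up, and the
# scalar values here are assumptions). `x` is a CUDA torch.Tensor:
#
#     with nvfuser.FusionDefinition() as fd:
#         tv_x = partially_contig_tensor(fd, x)
#         s_eps = fd.define_scalar(1e-5)      # constant scalars, not inputs
#         s_momentum = fd.define_scalar(0.1)
#         out, mean, invstd = norm_fusion_forward(
#             fd, [x], tv_x, None, None, None, None, s_eps, True, s_momentum,
#             False, torch_dtype_to_nvfuser_dtype(x.dtype),
#             stat_axes=[NamedAxis.BATCH, NamedAxis.CHANNEL],  # InstanceNorm
#         )
#         fd.add_output(out)
#         fd.add_output(mean)
#         fd.add_output(invstd)
#     out_t, mean_t, invstd_t = fd.execute([x])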


def norm_fusion_backward(
    fd: "nvfuser.FusionDefinition",
    inputs: List[torch.Tensor],
    x: "nvfuser.Tensor",
    grad_output: "nvfuser.Tensor",
    mean: Optional["nvfuser.Tensor"],
    invstd: "nvfuser.Tensor",
    weight: Optional["nvfuser.Tensor"],
    bias: Optional["nvfuser.Tensor"],
    running_mean: Optional["nvfuser.Tensor"],
    running_var: Optional["nvfuser.Tensor"],
    use_input_stats: bool,
    channels_last: bool,
    x_datatype: "nvfuser.DataType",
    *,
    stat_axes: List[NamedAxis],
) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]:
    """Modify FusionDefinition to add a generic normalization layer (backward).

    Args:
        fd: An initialized FusionDefinition.
        inputs: A list of :class:`torch.Tensor` inputs to the
            `FusionDefinition` `fd`.
        x: The input NVFuser tensor.
        grad_output: NVFuser tensor representing the gradient of the loss with
            respect to the downstream activation (the typical input to
            backward()).
        mean: The mean used in the forward normalization.
        invstd: The reciprocal of the standard deviation used in the forward
            normalization.
        weight: If given, the `Tensor` the normed output was multiplied by in
            the forward pass. It should be one-dimensional if
            `NamedAxis.CHANNEL` is in `stat_axes`, and zero-dimensional
            otherwise.
        bias: If given, the `Tensor` added to the normed output in the forward
            pass. It should be one-dimensional if `NamedAxis.CHANNEL` is in
            `stat_axes`, and zero-dimensional otherwise.
        running_mean: If given, a running mean estimate that will be modified
            in place.
        running_var: If given, a running variance estimate that will be
            modified in place.
        use_input_stats: Whether to compute the stats of this batch or to
            _only_ use the provided running_mean and running_var.
        channels_last: Whether channels are in position -1 (`True`) or 1
            (`False`).
        x_datatype: :class:`DataType` of input :class:`Tensor` `x`.
        stat_axes: A list of `NamedAxis` objects indicating the combination of
            axes with which to index the computed statistics. This can be used
            to implement multiple types of normalization layers, since most of
            those differ only in which axes are reduced over.

    Returns:
        The gradients with respect to the input, weight, and bias. Note that
        `fd.add_output` is _not_ called by this function.
    """
    assert not (
        (running_var is None) ^ (running_mean is None)
    ), "If either running_mean or running_var is given, both must be given"

    # dyn_shape holds Scalars describing the size of the input x
    dyn_shape = fd.ops.tensor_sizes(x)

    num_dims = len(dyn_shape)

    batch_dim = 0
    batch_size = dyn_shape[batch_dim]

    channel_dim = num_dims - 1 if channels_last else 1
    num_channels = dyn_shape[channel_dim]

    # Running stats will be kept possibly per channel but never per instance,
    # so we will reduce along batch_dim before updating running stats.
    # These masks are used to broadcast in the spatial dims.
    is_spatial_dim = [True] * num_dims
    is_spatial_or_batch_dim = [True] * num_dims

    if NamedAxis.BATCH in stat_axes:
        is_spatial_dim[batch_dim] = False
    if NamedAxis.CHANNEL in stat_axes:
        is_spatial_dim[channel_dim] = False
        is_spatial_or_batch_dim[channel_dim] = False
    x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag]
    num_features = fd.define_scalar(1)
    for ax in x_reduction_axes:
        num_features = fd.ops.mul(num_features, dyn_shape[ax])

    mean = fd.ops.broadcast(mean, is_spatial_dim)

    norm = fd.ops.reciprocal(num_features)
    grad_output_sum = fd.ops.sum(grad_output, x_reduction_axes)
    dot_p = fd.ops.sum(
        fd.ops.mul(
            grad_output,
            fd.ops.sub(x, mean),
        ),
        x_reduction_axes,
    )
    grad_mean = fd.ops.broadcast(fd.ops.mul(grad_output_sum, norm), is_spatial_dim)
    proj_scale = fd.ops.broadcast(
        fd.ops.mul(
            fd.ops.mul(dot_p, norm),
            fd.ops.mul(invstd, invstd),
        ),
        is_spatial_dim,
    )

    invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim)
    grad_scale = (
        invstd_bcast
        if weight is None
        else fd.ops.mul(
            invstd_bcast,
            fd.ops.broadcast(weight, is_spatial_or_batch_dim),
        )
    )
    if use_input_stats:
        proj = fd.ops.mul(fd.ops.sub(x, mean), proj_scale)
        grad_input = fd.ops.mul(
            fd.ops.sub(
                fd.ops.sub(grad_output, proj),
                grad_mean,
            ),
            grad_scale,
        )
    else:
        grad_input = fd.ops.mul(grad_output, grad_scale)

    if weight is not None:
        grad_weight = fd.ops.mul(dot_p, invstd)
        grad_weight_reduced = fd.ops.sum(grad_weight, [0])
    else:
        grad_weight_reduced = None
    if bias is not None:
        grad_bias = grad_output_sum
        grad_bias_reduced = fd.ops.sum(grad_bias, [0])
    else:
        grad_bias_reduced = None

    return grad_input, grad_weight_reduced, grad_bias_reduced
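
# For reference, with n = num_features and per-group statistics, the training
# branch above computes the standard normalization backward:
#     grad_input = (grad_output
#                   - (x - mean) * dot_p / n * invstd**2   # "proj" term
#                   - sum(grad_output) / n                 # "grad_mean" term
#                  ) * invstd * weight
# In inference mode the stats are constants, so only the invstd * weight
# scaling survives.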


class NormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,  # contexts are actually objects of the type we are currently defining
        x: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        use_input_stats: bool,
        momentum: float,
        eps: float,
        unbiased: bool,
        stat_axes: List[NamedAxis],
    ) -> torch.Tensor:
        # When x.shape[1] == 1, is_contiguous will report the tensor as
        # channels_last even when it is ordinarily contiguous. This causes
        # some issues, so we only detect channels_last when channels > 1
        channels_last = x.shape[1] > 1 and (
            x.is_contiguous(memory_format=torch.channels_last)
            or x.is_contiguous(memory_format=torch.channels_last_3d)
        )
        xorig = x
        if channels_last:
            order = [0] + list(range(2, len(x.shape))) + [1]
            x = x.permute(order)

        x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype)

        with nvfuser.FusionDefinition() as fd:
            tv_x = partially_contig_tensor(fd, x)
            inputs = [x]
            if weight is not None:
                tv_weight = partially_contig_tensor(fd, weight)
                inputs.append(weight)
            else:
                tv_weight = None

            if bias is not None:
                tv_bias = partially_contig_tensor(fd, bias)
                inputs.append(bias)
            else:
                tv_bias = None

            if running_mean is None:
                tv_running_mean = None
                tv_running_var = None
            else:
                assert running_var is not None
                tv_running_mean = partially_contig_tensor(fd, running_mean)
                tv_running_var = partially_contig_tensor(fd, running_var)
                inputs.extend([running_mean, running_var])

            s_momentum = fd.define_scalar(nvfuser.DataType.Double)
            s_eps = fd.define_scalar(nvfuser.DataType.Double)
            inputs.extend([momentum, eps])

            # cast inputs if necessary
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float)
            if weight is not None and weight.dtype in [torch.half, torch.bfloat16]:
                tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float)
            if bias is not None and bias.dtype in [torch.half, torch.bfloat16]:
                tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float)

            out, mean, invstd = norm_fusion_forward(
                fd,
                inputs,
                tv_x,
                tv_weight,
                tv_bias,
                tv_running_mean,
                tv_running_var,
                s_eps,
                use_input_stats,
                s_momentum,
                channels_last,
                x_datatype=x_datatype,
                unbiased=unbiased,
                stat_axes=stat_axes,
            )

            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                out = fd.ops.cast(out, x_datatype)

            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)

        out, mean, invstd = fd.execute(inputs)
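        # fd.execute returns only the non-aliased outputs, in add_output
        # order; the running-stat outputs registered with alias_input above
        # are written into running_mean/running_var in place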

        ctx.stat_axes = stat_axes
        ctx.use_input_stats = use_input_stats
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(x, weight, bias, running_mean, running_var, mean, invstd)
        if channels_last:
            order = [0, len(x.shape) - 1] + list(range(1, len(x.shape) - 1))
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert xorig.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert xorig.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                raise RuntimeError(
                    "unhandled channels_last format variation in forward"
                )
        return out

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        """Instance norm backward using NVFuser"""
        if ctx.channels_last:
            order = [0] + list(range(2, len(grad_output.shape))) + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        # assert ctx.saved_tensors[0].is_contiguous()
        # grad_output = grad_output.contiguous()
        x, weight, bias, running_mean, running_var, mean, invstd = ctx.saved_tensors

        with nvfuser.FusionDefinition() as fd:
            tv_x = partially_contig_tensor(fd, x)
            if x.dtype in [torch.half, torch.bfloat16]:
                tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float)
            inputs = [x]
            if weight is not None:
                tv_weight = partially_contig_tensor(fd, weight)
                if weight.dtype in [torch.half, torch.bfloat16]:
                    tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float)
                inputs.append(weight)
            else:
                tv_weight = None
            if bias is not None:
                tv_bias = partially_contig_tensor(fd, bias)
                if bias.dtype in [torch.half, torch.bfloat16]:
                    tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float)
                inputs.append(bias)
            else:
                tv_bias = None
            if running_mean is not None:
                tv_running_mean = partially_contig_tensor(fd, running_mean)
                if running_mean.dtype in [torch.half, torch.bfloat16]:
                    tv_running_mean = fd.ops.cast(
                        tv_running_mean, nvfuser.DataType.Float
                    )
                inputs.append(running_mean)
            else:
                tv_running_mean = None
            if running_var is not None:
                tv_running_var = partially_contig_tensor(fd, running_var)
                if running_var.dtype in [torch.half, torch.bfloat16]:
                    tv_running_var = fd.ops.cast(tv_running_var, nvfuser.DataType.Float)
                inputs.append(running_var)
            else:
                tv_running_var = None

            tv_mean = partially_contig_tensor(fd, mean)
            if mean.dtype in [torch.half, torch.bfloat16]:
                tv_mean = fd.ops.cast(tv_mean, nvfuser.DataType.Float)
            inputs.append(mean)
            tv_invstd = partially_contig_tensor(fd, invstd)
            if invstd.dtype in [torch.half, torch.bfloat16]:
                tv_invstd = fd.ops.cast(tv_invstd, nvfuser.DataType.Float)
            inputs.append(invstd)

            tv_grad_output = partially_contig_tensor(fd, grad_output)
            if grad_output.dtype in [torch.half, torch.bfloat16]:
                tv_grad_output = fd.ops.cast(tv_grad_output, nvfuser.DataType.Float)
            inputs.append(grad_output)

            x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype)

            grad_input, grad_weight, grad_bias = norm_fusion_backward(
                fd,
                inputs,
                tv_x,
                tv_grad_output,
                tv_mean,
                tv_invstd,
                tv_weight,
                tv_bias,
                tv_running_mean,
                tv_running_var,
                ctx.use_input_stats,
                ctx.channels_last,
                x_datatype=x_datatype,
                stat_axes=ctx.stat_axes,
            )

            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                grad_input = fd.ops.cast(grad_input, x_datatype)
            fd.add_output(grad_input)

            if weight is not None:
                if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                    grad_weight = fd.ops.cast(grad_weight, x_datatype)
                fd.add_output(grad_weight)

            if bias is not None:
                if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                    grad_bias = fd.ops.cast(grad_bias, x_datatype)
                fd.add_output(grad_bias)

        res = fd.execute(inputs)
        grad_input = res[0]
        c = 1
        if weight is not None:
            grad_weight = res[c]
            c += 1
        else:
            grad_weight = None
        if bias is not None:
            grad_bias = res[c]
            c += 1
        else:
            grad_bias = None

        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + list(
                range(1, len(grad_input.shape) - 1)
            )
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                raise RuntimeError(
                    "unhandled channels_last format variation in backward"
                )
        return (
            grad_input,
            grad_weight,
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class _NormNVFuserBase(torch.nn.modules.batchnorm._NormBase):
    stat_axes: Optional[List[NamedAxis]] = None

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def _check_input_dim(self, input: torch.Tensor) -> None:
        raise NotImplementedError

    def _load_from_state_dict(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Any,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ("running_mean", "running_var"):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    "Unexpected running stats buffer(s) {names} for {klass} "
                    "with track_running_stats=False. If state_dict is a "
                    "checkpoint saved before 0.4.0, this may be expected "
                    "because {klass} does not track running stats by default "
                    "since 0.4.0. Please remove these keys from state_dict. If "
                    "the running stats are actually needed, instead set "
                    "track_running_stats=True in {klass} to enable them. See "
                    "the documentation of {klass} for details.".format(
                        names=" and ".join(
                            '"{}"'.format(k) for k in running_stats_keys
                        ),
                        klass=self.__class__.__name__,
                    )
                )
                for key in running_stats_keys:
                    state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        out = NormNVFuserFunction.apply(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.training or not self.track_running_stats,
            self.momentum,
            self.eps,
            False,  # unbiased=False to match PyTorch functionality
            self.stat_axes,
        )
        return out


class _InstanceNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL]


class _BatchNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.CHANNEL]


class _LayerNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.BATCH]


class InstanceNorm1dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 3:
            raise ValueError("expected 3D input (got {}D input)".format(input.dim()))


class InstanceNorm2dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))


class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
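
For reference, a minimal end-to-end usage sketch of the exported modules (an
illustrative assumption: this wheel is installed, a CUDA device is available,
and nvfuser.contrib.nn re-exports the classes as listed in __all__ above):

    import torch
    from nvfuser.contrib.nn import InstanceNorm2dNVFuser

    # affine=True creates learnable weight/bias, exercising the
    # grad_weight/grad_bias paths in the backward fusion
    norm = InstanceNorm2dNVFuser(num_features=4, affine=True).cuda()

    x = torch.randn(8, 4, 32, 32, device="cuda", requires_grad=True)
    y = norm(x)          # builds and executes an nvfuser FusionDefinition
    y.sum().backward()   # runs the norm_fusion_backward fusion

    assert y.shape == x.shape and x.grad is not None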