nvfuser_cu121_torch25-0.2.25.dev20250201-cp310-cp310-manylinux_2_28_x86_64.whl
- nvfuser/_C.cpython-310-x86_64-linux-gnu.so +0 -0
- nvfuser/__init__.py +618 -0
- nvfuser/__init__.pyi +4 -0
- nvfuser/contrib/__init__.py +9 -0
- nvfuser/contrib/nn/__init__.py +13 -0
- nvfuser/contrib/nn/normalization.py +725 -0
- nvfuser/include/nvfuser/alias_analysis.h +116 -0
- nvfuser/include/nvfuser/bfs.h +929 -0
- nvfuser/include/nvfuser/codegen.h +26 -0
- nvfuser/include/nvfuser/compute_at.h +28 -0
- nvfuser/include/nvfuser/compute_at_map.h +394 -0
- nvfuser/include/nvfuser/contiguity.h +351 -0
- nvfuser/include/nvfuser/cuda_utils.h +50 -0
- nvfuser/include/nvfuser/debug.h +50 -0
- nvfuser/include/nvfuser/device_lower/analysis/bank_conflict.h +53 -0
- nvfuser/include/nvfuser/device_lower/analysis/circular_buffer.h +109 -0
- nvfuser/include/nvfuser/device_lower/analysis/device_version.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/divisible_split.h +28 -0
- nvfuser/include/nvfuser/device_lower/analysis/fused_reduction.h +36 -0
- nvfuser/include/nvfuser/device_lower/analysis/index_compute.h +322 -0
- nvfuser/include/nvfuser/device_lower/analysis/predicate_elimination.h +71 -0
- nvfuser/include/nvfuser/device_lower/analysis/sync_information.h +47 -0
- nvfuser/include/nvfuser/device_lower/analysis/tensor_memory.h +65 -0
- nvfuser/include/nvfuser/device_lower/analysis/thread_predicate.h +158 -0
- nvfuser/include/nvfuser/device_lower/analysis/tma.h +93 -0
- nvfuser/include/nvfuser/device_lower/analysis/trivial_broadcast.h +75 -0
- nvfuser/include/nvfuser/device_lower/id_model_options.h +135 -0
- nvfuser/include/nvfuser/device_lower/lower2device.h +391 -0
- nvfuser/include/nvfuser/device_lower/pass/alias_memory.h +37 -0
- nvfuser/include/nvfuser/device_lower/pass/allocation.h +32 -0
- nvfuser/include/nvfuser/device_lower/pass/circular_buffer.h +191 -0
- nvfuser/include/nvfuser/device_lower/pass/expr_sort.h +17 -0
- nvfuser/include/nvfuser/device_lower/pass/fusion_simplifier.h +21 -0
- nvfuser/include/nvfuser/device_lower/pass/grid_serialization.h +26 -0
- nvfuser/include/nvfuser/device_lower/pass/index.h +200 -0
- nvfuser/include/nvfuser/device_lower/pass/inline_ptx.h +16 -0
- nvfuser/include/nvfuser/device_lower/pass/insert_syncs.h +39 -0
- nvfuser/include/nvfuser/device_lower/pass/instrument.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/loop_rotation.h +150 -0
- nvfuser/include/nvfuser/device_lower/pass/loops.h +68 -0
- nvfuser/include/nvfuser/device_lower/pass/magic_zero.h +86 -0
- nvfuser/include/nvfuser/device_lower/pass/misaligned_vectorization.h +118 -0
- nvfuser/include/nvfuser/device_lower/pass/predicate.h +23 -0
- nvfuser/include/nvfuser/device_lower/pass/replace_size.h +24 -0
- nvfuser/include/nvfuser/device_lower/pass/scalar_hoist.h +115 -0
- nvfuser/include/nvfuser/device_lower/pass/unroll.h +98 -0
- nvfuser/include/nvfuser/device_lower/pass/vectorize_welford.h +45 -0
- nvfuser/include/nvfuser/device_lower/pass/warp_reduce.h +23 -0
- nvfuser/include/nvfuser/device_lower/utils.h +382 -0
- nvfuser/include/nvfuser/device_lower/validation.h +74 -0
- nvfuser/include/nvfuser/disjoint_set.h +556 -0
- nvfuser/include/nvfuser/dispatch.h +334 -0
- nvfuser/include/nvfuser/driver_api.h +49 -0
- nvfuser/include/nvfuser/dynamic_transform.h +316 -0
- nvfuser/include/nvfuser/dynamic_type/C++20/type_traits +37 -0
- nvfuser/include/nvfuser/dynamic_type/dynamic_type.h +969 -0
- nvfuser/include/nvfuser/dynamic_type/error.h +24 -0
- nvfuser/include/nvfuser/dynamic_type/type_traits.h +703 -0
- nvfuser/include/nvfuser/evaluator_common.h +295 -0
- nvfuser/include/nvfuser/exceptions.h +283 -0
- nvfuser/include/nvfuser/expr_evaluator.h +125 -0
- nvfuser/include/nvfuser/expr_simplifier.h +218 -0
- nvfuser/include/nvfuser/flatbuffers/allocator.h +68 -0
- nvfuser/include/nvfuser/flatbuffers/array.h +253 -0
- nvfuser/include/nvfuser/flatbuffers/base.h +486 -0
- nvfuser/include/nvfuser/flatbuffers/buffer.h +154 -0
- nvfuser/include/nvfuser/flatbuffers/buffer_ref.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/code_generator.h +80 -0
- nvfuser/include/nvfuser/flatbuffers/code_generators.h +234 -0
- nvfuser/include/nvfuser/flatbuffers/default_allocator.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/detached_buffer.h +114 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffer_builder.h +1225 -0
- nvfuser/include/nvfuser/flatbuffers/flatbuffers.h +272 -0
- nvfuser/include/nvfuser/flatbuffers/flatc.h +130 -0
- nvfuser/include/nvfuser/flatbuffers/flex_flat_util.h +36 -0
- nvfuser/include/nvfuser/flatbuffers/flexbuffers.h +1889 -0
- nvfuser/include/nvfuser/flatbuffers/grpc.h +300 -0
- nvfuser/include/nvfuser/flatbuffers/hash.h +127 -0
- nvfuser/include/nvfuser/flatbuffers/idl.h +1359 -0
- nvfuser/include/nvfuser/flatbuffers/minireflect.h +420 -0
- nvfuser/include/nvfuser/flatbuffers/reflection.h +522 -0
- nvfuser/include/nvfuser/flatbuffers/reflection_generated.h +1471 -0
- nvfuser/include/nvfuser/flatbuffers/registry.h +128 -0
- nvfuser/include/nvfuser/flatbuffers/stl_emulation.h +513 -0
- nvfuser/include/nvfuser/flatbuffers/string.h +64 -0
- nvfuser/include/nvfuser/flatbuffers/struct.h +53 -0
- nvfuser/include/nvfuser/flatbuffers/table.h +168 -0
- nvfuser/include/nvfuser/flatbuffers/util.h +731 -0
- nvfuser/include/nvfuser/flatbuffers/vector.h +393 -0
- nvfuser/include/nvfuser/flatbuffers/vector_downward.h +273 -0
- nvfuser/include/nvfuser/flatbuffers/verifier.h +317 -0
- nvfuser/include/nvfuser/fusion.h +511 -0
- nvfuser/include/nvfuser/fusion_guard.h +37 -0
- nvfuser/include/nvfuser/fusion_profiler.h +311 -0
- nvfuser/include/nvfuser/fusion_segmenter.h +751 -0
- nvfuser/include/nvfuser/global_allocator.h +27 -0
- nvfuser/include/nvfuser/grouped_reduction.h +47 -0
- nvfuser/include/nvfuser/host_ir/container.h +60 -0
- nvfuser/include/nvfuser/host_ir/executor.h +152 -0
- nvfuser/include/nvfuser/host_ir/host_ir.h +320 -0
- nvfuser/include/nvfuser/host_ir/lower.h +35 -0
- nvfuser/include/nvfuser/id_model/circular_buffer_indexing.h +56 -0
- nvfuser/include/nvfuser/id_model/contiguity.h +166 -0
- nvfuser/include/nvfuser/id_model/id_model.h +359 -0
- nvfuser/include/nvfuser/id_model/id_model_index_compute.h +81 -0
- nvfuser/include/nvfuser/id_model/indexing.h +208 -0
- nvfuser/include/nvfuser/id_model/indexing_traversal.h +72 -0
- nvfuser/include/nvfuser/id_model/indexing_utils.h +62 -0
- nvfuser/include/nvfuser/id_model/loop_promotion.h +180 -0
- nvfuser/include/nvfuser/id_model/predicate_indexing.h +104 -0
- nvfuser/include/nvfuser/id_model/schedule.h +54 -0
- nvfuser/include/nvfuser/id_model/to_string.h +87 -0
- nvfuser/include/nvfuser/id_model/transform_replay.h +58 -0
- nvfuser/include/nvfuser/id_model/utils.h +176 -0
- nvfuser/include/nvfuser/id_model/validation_utils.h +55 -0
- nvfuser/include/nvfuser/index_compute.h +651 -0
- nvfuser/include/nvfuser/instrumentation.h +107 -0
- nvfuser/include/nvfuser/ir/all_nodes.h +14 -0
- nvfuser/include/nvfuser/ir/base_nodes.h +687 -0
- nvfuser/include/nvfuser/ir/builder.h +215 -0
- nvfuser/include/nvfuser/ir/builder_passkey.h +29 -0
- nvfuser/include/nvfuser/ir/cloner.h +185 -0
- nvfuser/include/nvfuser/ir/container.h +226 -0
- nvfuser/include/nvfuser/ir/graphviz.h +119 -0
- nvfuser/include/nvfuser/ir/interface_nodes.h +957 -0
- nvfuser/include/nvfuser/ir/internal_base_nodes.h +744 -0
- nvfuser/include/nvfuser/ir/internal_nodes.h +2792 -0
- nvfuser/include/nvfuser/ir/iostream.h +98 -0
- nvfuser/include/nvfuser/ir/printer.h +57 -0
- nvfuser/include/nvfuser/ir/utils.h +801 -0
- nvfuser/include/nvfuser/iter_visitor.h +661 -0
- nvfuser/include/nvfuser/kernel.h +299 -0
- nvfuser/include/nvfuser/kernel_db/kernel_db.h +109 -0
- nvfuser/include/nvfuser/kernel_db/utils.h +37 -0
- nvfuser/include/nvfuser/kernel_ir.h +1457 -0
- nvfuser/include/nvfuser/kernel_ir_dispatch.h +147 -0
- nvfuser/include/nvfuser/linked_hash_map.h +97 -0
- nvfuser/include/nvfuser/logical_domain_map.h +577 -0
- nvfuser/include/nvfuser/macros.h +23 -0
- nvfuser/include/nvfuser/mma_type.h +257 -0
- nvfuser/include/nvfuser/multidevice/c10d_mock.h +175 -0
- nvfuser/include/nvfuser/multidevice/communication.h +232 -0
- nvfuser/include/nvfuser/multidevice/communicator.h +179 -0
- nvfuser/include/nvfuser/multidevice/device_mesh.h +95 -0
- nvfuser/include/nvfuser/multidevice/executor.h +107 -0
- nvfuser/include/nvfuser/multidevice/multidevice.h +18 -0
- nvfuser/include/nvfuser/multidevice/utils.h +187 -0
- nvfuser/include/nvfuser/non_divisible_split.h +86 -0
- nvfuser/include/nvfuser/opaque_type.h +129 -0
- nvfuser/include/nvfuser/ops/alias.h +192 -0
- nvfuser/include/nvfuser/ops/all_ops.h +13 -0
- nvfuser/include/nvfuser/ops/arith.h +712 -0
- nvfuser/include/nvfuser/ops/composite.h +130 -0
- nvfuser/include/nvfuser/ops/indexing.h +55 -0
- nvfuser/include/nvfuser/ops/normalization.h +263 -0
- nvfuser/include/nvfuser/ops/utils.h +127 -0
- nvfuser/include/nvfuser/options.h +313 -0
- nvfuser/include/nvfuser/parallel_dimension_map.h +95 -0
- nvfuser/include/nvfuser/parallel_type_bitmap.h +365 -0
- nvfuser/include/nvfuser/polymorphic_value.h +432 -0
- nvfuser/include/nvfuser/predicate_compute.h +213 -0
- nvfuser/include/nvfuser/python_frontend/distributed_tensor.h +50 -0
- nvfuser/include/nvfuser/python_frontend/fusion_cache.h +298 -0
- nvfuser/include/nvfuser/python_frontend/fusion_definition.h +372 -0
- nvfuser/include/nvfuser/python_frontend/fusion_record.h +3124 -0
- nvfuser/include/nvfuser/python_frontend/fusion_state.h +143 -0
- nvfuser/include/nvfuser/python_frontend/python_bindings.h +27 -0
- nvfuser/include/nvfuser/python_frontend/segmentation.h +246 -0
- nvfuser/include/nvfuser/python_frontend/translation.h +20 -0
- nvfuser/include/nvfuser/python_frontend/translation_utils.h +308 -0
- nvfuser/include/nvfuser/scheduler/all_schedulers.h +17 -0
- nvfuser/include/nvfuser/scheduler/ampere_multi_matmul.h +206 -0
- nvfuser/include/nvfuser/scheduler/cache_policy_refiner.h +19 -0
- nvfuser/include/nvfuser/scheduler/compile_time_info.h +322 -0
- nvfuser/include/nvfuser/scheduler/debug_utils.h +68 -0
- nvfuser/include/nvfuser/scheduler/expr_eval_sched.h +45 -0
- nvfuser/include/nvfuser/scheduler/heuristic.h +113 -0
- nvfuser/include/nvfuser/scheduler/hopper_multi_matmul.h +204 -0
- nvfuser/include/nvfuser/scheduler/mark_aliases.h +19 -0
- nvfuser/include/nvfuser/scheduler/matmul.h +40 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic.h +293 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin.h +65 -0
- nvfuser/include/nvfuser/scheduler/matmul_heuristic_plugin_api.h +99 -0
- nvfuser/include/nvfuser/scheduler/matmul_utils.h +54 -0
- nvfuser/include/nvfuser/scheduler/mma_utils.h +500 -0
- nvfuser/include/nvfuser/scheduler/multi_matmul.h +74 -0
- nvfuser/include/nvfuser/scheduler/no_op.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner.h +49 -0
- nvfuser/include/nvfuser/scheduler/normalization_inner_outer.h +51 -0
- nvfuser/include/nvfuser/scheduler/normalization_outer.h +48 -0
- nvfuser/include/nvfuser/scheduler/normalization_utils.h +379 -0
- nvfuser/include/nvfuser/scheduler/pointwise.h +183 -0
- nvfuser/include/nvfuser/scheduler/pointwise_heuristic.h +118 -0
- nvfuser/include/nvfuser/scheduler/pointwise_utils.h +24 -0
- nvfuser/include/nvfuser/scheduler/reduction.h +43 -0
- nvfuser/include/nvfuser/scheduler/reduction_heuristic.h +339 -0
- nvfuser/include/nvfuser/scheduler/reduction_utils.h +159 -0
- nvfuser/include/nvfuser/scheduler/registry.h +97 -0
- nvfuser/include/nvfuser/scheduler/registry_utils.h +111 -0
- nvfuser/include/nvfuser/scheduler/resize.h +41 -0
- nvfuser/include/nvfuser/scheduler/resize_heuristic.h +67 -0
- nvfuser/include/nvfuser/scheduler/runtime_info.h +166 -0
- nvfuser/include/nvfuser/scheduler/scheduler_types.h +80 -0
- nvfuser/include/nvfuser/scheduler/transpose.h +114 -0
- nvfuser/include/nvfuser/scheduler/transpose_heuristic.h +164 -0
- nvfuser/include/nvfuser/scheduler/utils.h +771 -0
- nvfuser/include/nvfuser/scheduler/vectorize_helper.h +349 -0
- nvfuser/include/nvfuser/serde/factory.h +55 -0
- nvfuser/include/nvfuser/serde/fusion_cache_generated.h +4319 -0
- nvfuser/include/nvfuser/serde/fusion_record.h +124 -0
- nvfuser/include/nvfuser/serde/polymorphic_value.h +52 -0
- nvfuser/include/nvfuser/serde/utils.h +34 -0
- nvfuser/include/nvfuser/struct.inl +127 -0
- nvfuser/include/nvfuser/swizzle.h +54 -0
- nvfuser/include/nvfuser/sys_utils.h +40 -0
- nvfuser/include/nvfuser/tensor_metadata.h +118 -0
- nvfuser/include/nvfuser/tma.h +124 -0
- nvfuser/include/nvfuser/transform_iter.h +522 -0
- nvfuser/include/nvfuser/transform_replay.h +297 -0
- nvfuser/include/nvfuser/transform_rfactor.h +33 -0
- nvfuser/include/nvfuser/transform_view.h +136 -0
- nvfuser/include/nvfuser/type.h +1125 -0
- nvfuser/include/nvfuser/type_promotion.h +61 -0
- nvfuser/include/nvfuser/utils.h +619 -0
- nvfuser/include/nvfuser/val_graph.h +446 -0
- nvfuser/include/nvfuser/val_graph_visitor.h +259 -0
- nvfuser/include/nvfuser/validator_utils.h +92 -0
- nvfuser/include/nvfuser/vectorization_info.h +31 -0
- nvfuser/include/nvfuser/visibility.h +21 -0
- nvfuser/lib/libnvfuser_codegen.so +0 -0
- nvfuser/nvfuser_version.py +69 -0
- nvfuser/pytorch_utils.py +184 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig-release.cmake +20 -0
- nvfuser/share/cmake/nvfuser/NvfuserConfig.cmake +106 -0
- nvfuser/utils.py +18 -0
- nvfuser/version.py +1 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/LICENSE +976 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/METADATA +20 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/RECORD +242 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/WHEEL +5 -0
- nvfuser_cu121_torch25-0.2.25.dev20250201.dist-info/top_level.txt +1 -0
- nvfuser_cu121_torch25.libs/libnvToolsExt-847d78f2.so.1.0.0 +0 -0
nvfuser/contrib/nn/normalization.py
@@ -0,0 +1,725 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
import enum
from typing import Any, Dict, List, Optional, Tuple

import torch

import nvfuser

from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype


__all__ = [
    "InstanceNorm1dNVFuser",
    "InstanceNorm2dNVFuser",
    "InstanceNorm3dNVFuser",
]


NamedAxis = enum.Enum("NamedAxis", ["BATCH", "CHANNEL"])


def partially_contig_tensor(
    fd: "nvfuser.FusionDefinition",
    x: torch.Tensor,
) -> "nvfuser.Tensor":
    return fd.define_tensor(
        shape=[1 if dim_size == 1 else -1 for dim_size in x.size()],
        contiguity=nvfuser.compute_contiguity(x.size(), x.stride()),
        dtype=torch_dtype_to_nvfuser_dtype(x.dtype),
    )

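# (Added note) A minimal sketch, assuming a plain contiguous NCHW input, of
# what partially_contig_tensor records in the FusionDefinition: size-1 dims
# are pinned to extent 1, every other extent stays symbolic (-1), and
# compute_contiguity derives per-dimension contiguity flags from the strides.
#
#     t = torch.randn(2, 3, 4, 5, device="cuda")        # hypothetical input
#     nvfuser.compute_contiguity(t.size(), t.stride())  # [True, True, True, True]
#     # define_tensor then sees shape=[-1, -1, -1, -1]
#
# so one definition can be replayed for any input sharing the same broadcast
# pattern and contiguity, not just the same concrete sizes.
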
def norm_fusion_forward(
    fd: "nvfuser.FusionDefinition",
    inputs: List[torch.Tensor],
    x: "nvfuser.Tensor",
    weight: Optional["nvfuser.Tensor"],
    bias: Optional["nvfuser.Tensor"],
    running_mean: Optional["nvfuser.Tensor"],
    running_var: Optional["nvfuser.Tensor"],
    eps: "nvfuser.Scalar",
    use_input_stats: bool,
    momentum: "nvfuser.Scalar",
    channels_last: bool,
    x_datatype: "nvfuser.DataType",
    unbiased: bool = False,
    *,
    stat_axes: List[NamedAxis],
) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]:
    """Modify FusionDefinition to add a generic normalization layer (forward).

    This can be used to construct a BatchNorm, GroupNorm, InstanceNorm, or
    LayerNorm network by indicating different sets of axes to preserve.

    BatchNorm: `stat_axes = [NamedAxis.CHANNEL]`
    LayerNorm: `stat_axes = [NamedAxis.BATCH]`
    InstanceNorm: `stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL]`

    Args:
        fd: An initialized FusionDefinition.
        inputs: A list of :class:`torch.Tensor` inputs to the
            `FusionDefinition` `fd`.
        x: An input NVFuser tensor.
        weight: If given, multiply the normed output by this `Tensor`. It
            should be one-dimensional if `NamedAxis.CHANNEL` is in
            `stat_axes`, and zero-dimensional otherwise. It will be
            broadcast along all other dimensions.
        bias: If given, add this `Tensor` to the normed output. It should be
            one-dimensional if `NamedAxis.CHANNEL` is in `stat_axes`, and
            zero-dimensional otherwise. It will be broadcast along all other
            dimensions.
        running_mean: If given, a running mean estimate that will be modified
            in place.
        running_var: If given, a running variance estimate that will be
            modified in place.
        eps: Amount to regularize the square root needed to convert variance
            to standard deviation.
        use_input_stats: Whether to compute the stats of this batch or to
            _only_ use the provided running_mean and running_var.
        momentum: Momentum for the exponentially weighted moving average of
            running stats.
        channels_last: Whether channels are in position -1 (`True`) or 1
            (`False`).
        x_datatype: :class:`DataType` of input :class:`Tensor` `x`.
        unbiased: Whether to use unbiased variance for computing current batch
            statistics. Note that unbiased estimates are always used for
            running variance updates, regardless of this argument's value.
        stat_axes: A list of `NamedAxis` objects indicating a combination of
            axes with which to index the computed statistics. This can be
            used to implement multiple types of normalization layers, since
            most of those differ only in which axes are reduced over.

    Returns:
        The normalized output, as well as the mean and 1/std. Note that
        `fd.add_output` is _not_ called by this function.
    """
    assert (running_var is None) == (
        running_mean is None
    ), "running_mean and running_var must both be given or both be None"

    # dyn_shape holds Scalars describing the size of the input x
    dyn_shape = fd.ops.tensor_sizes(x)

    num_dims = len(dyn_shape)

    batch_dim = 0
    batch_size = dyn_shape[batch_dim]

    channel_dim = num_dims - 1 if channels_last else 1
    num_channels = dyn_shape[channel_dim]

    # Running stats will be kept possibly for channel but never by instance,
    # so we will reduce along batch_dim before updating running stats.
    # These are used to broadcast in spatial dims
    is_spatial_dim = [True] * num_dims
    is_spatial_or_batch_dim = [True] * num_dims

    if NamedAxis.BATCH in stat_axes:
        is_spatial_dim[batch_dim] = False
    if NamedAxis.CHANNEL in stat_axes:
        is_spatial_dim[channel_dim] = False
        is_spatial_or_batch_dim[channel_dim] = False
    x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag]
    num_features = fd.define_scalar(1)
    for ax in x_reduction_axes:
        num_features = fd.ops.mul(num_features, dyn_shape[ax])

    if use_input_stats or running_mean is None:
        # var_mean takes a correction argument: 1 requests the unbiased
        # variance estimator and 0 the biased one, so we pass int(unbiased)
        x_var, x_mean = fd.ops.var_mean(x, x_reduction_axes, int(unbiased))
        if running_mean is not None:
            one = fd.define_scalar(1.0)
            rev_momentum = fd.ops.sub(one, momentum)

            # do running mean with momentum
            current_mean_hat = fd.ops.mul(x_mean, momentum)
            mean_hat = fd.ops.mul(running_mean, rev_momentum)
            new_mean_hat = fd.ops.add(mean_hat, current_mean_hat)

            # If computing stats for each instance, we don't want to keep
            # those for our running mean calculation, so we sum them here
            new_mean_sum = (
                fd.ops.sum(new_mean_hat, [0])
                if NamedAxis.BATCH in stat_axes
                else new_mean_hat
            )

            rev_batch_size = fd.ops.reciprocal(batch_size)
            new_mean_channels_only = fd.ops.mul(new_mean_sum, rev_batch_size)
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                new_mean_channels_only = fd.ops.cast(new_mean_channels_only, x_datatype)
            fd.add_output(new_mean_channels_only, alias_input=running_mean)

            # running var calculation
            x_var_unbiased = x_var
            if not unbiased:
                # multiply by correction to go from biased to unbiased estimate
                b2ub = fd.ops.div(
                    num_features, fd.ops.sub(num_features, fd.define_scalar(1))
                )
                x_var_unbiased = fd.ops.mul(x_var, b2ub)

            current_var_hat = fd.ops.mul(x_var_unbiased, momentum)
            var_hat = fd.ops.mul(running_var, rev_momentum)
            new_var_hat = fd.ops.add(var_hat, current_var_hat)

            # See above about reducing over batch dim for running stats
            new_var_sum = (
                fd.ops.sum(new_var_hat, [0])
                if NamedAxis.BATCH in stat_axes
                else new_var_hat
            )

            new_var_channels_only = fd.ops.mul(new_var_sum, rev_batch_size)
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                new_var_channels_only = fd.ops.cast(new_var_channels_only, x_datatype)
            fd.add_output(new_var_channels_only, alias_input=running_var)

        mean = x_mean
        mean_bcast = fd.ops.broadcast(mean, is_spatial_dim)
        x_sub_mean = fd.ops.sub(x, mean_bcast)

        var_eps = fd.ops.add(x_var, eps)
        invstd = fd.ops.rsqrt(var_eps)
        invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim)

        x_normed = fd.ops.mul(x_sub_mean, invstd_bcast)

    else:  # This is inference mode with running stats
        assert running_mean is not None
        r_mean_bcast = fd.ops.broadcast(running_mean, is_spatial_or_batch_dim)
        x_sub_mean = fd.ops.sub(x, r_mean_bcast)

        var_eps = fd.ops.add(running_var, eps)
        invstd = fd.ops.rsqrt(var_eps)
        invstd_bcast = fd.ops.broadcast(invstd, is_spatial_or_batch_dim)

        mean = running_mean
        x_normed = fd.ops.mul(x_sub_mean, invstd_bcast)

    if weight is not None:
        weight_bcast = fd.ops.broadcast(weight, is_spatial_or_batch_dim)
        x_normed = fd.ops.mul(x_normed, weight_bcast)
    if bias is not None:
        bias_bcast = fd.ops.broadcast(bias, is_spatial_or_batch_dim)
        x_normed = fd.ops.add(x_normed, bias_bcast)

    return x_normed, mean, invstd

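# (Added note) Worked example of how stat_axes selects the reduction above:
# for a 4-D input laid out as (N, C, H, W), batch_dim = 0 and channel_dim = 1,
# so
#
#     InstanceNorm, stat_axes=[BATCH, CHANNEL]:
#         is_spatial_dim = [False, False, True, True]
#         -> x_reduction_axes = [2, 3], one (mean, invstd) pair per (n, c)
#     BatchNorm, stat_axes=[CHANNEL]:
#         [True, False, True, True] -> reduce over [0, 2, 3], one pair per channel
#     LayerNorm, stat_axes=[BATCH]:
#         [False, True, True, True] -> reduce over [1, 2, 3], one pair per sample
#
# Running stats follow the usual EWMA, new = (1 - momentum) * old +
# momentum * current, with the biased batch variance rescaled by
# num_features / (num_features - 1) before entering the average.
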
def norm_fusion_backward(
    fd: "nvfuser.FusionDefinition",
    inputs: List[torch.Tensor],
    x: "nvfuser.Tensor",
    grad_output: "nvfuser.Tensor",
    mean: Optional["nvfuser.Tensor"],
    invstd: "nvfuser.Tensor",
    weight: Optional["nvfuser.Tensor"],
    bias: Optional["nvfuser.Tensor"],
    running_mean: Optional["nvfuser.Tensor"],
    running_var: Optional["nvfuser.Tensor"],
    use_input_stats: bool,
    channels_last: bool,
    x_datatype: "nvfuser.DataType",
    *,
    stat_axes: List[NamedAxis],
) -> Tuple["nvfuser.Tensor", "nvfuser.Tensor", "nvfuser.Tensor"]:
    """Modify FusionDefinition to add a generic normalization layer (backward).

    Args:
        fd: An initialized FusionDefinition.
        inputs: A list of :class:`torch.Tensor` inputs to the
            `FusionDefinition` `fd`.
        x: The input NVFuser tensor.
        grad_output: NVFuser tensor representing the gradient of the loss with
            respect to the downstream activation (the typical input to
            backward()).
        mean: The mean used in the forward normalization.
        invstd: The reciprocal of the standard deviation used in the forward
            normalization.
        weight: If given, the `Tensor` the normed output was multiplied by in
            the forward pass. It should be one-dimensional if
            `NamedAxis.CHANNEL` is in `stat_axes`, and zero-dimensional
            otherwise. It will be broadcast along all other dimensions.
        bias: If given, the `Tensor` added to the normed output in the forward
            pass. It should be one-dimensional if `NamedAxis.CHANNEL` is in
            `stat_axes`, and zero-dimensional otherwise. It will be broadcast
            along all other dimensions.
        running_mean: If given, a running mean estimate that will be modified
            in place.
        running_var: If given, a running variance estimate that will be
            modified in place.
        use_input_stats: Whether to compute the stats of this batch or to
            _only_ use the provided running_mean and running_var.
        channels_last: Whether channels are in position -1 (`True`) or 1
            (`False`).
        x_datatype: :class:`DataType` of input :class:`Tensor` `x`.
        stat_axes: A list of `NamedAxis` objects indicating a combination of
            axes with which to index the computed statistics. This can be
            used to implement multiple types of normalization layers, since
            most of those differ only in which axes are reduced over.

    Returns:
        The gradients with respect to the input, weight, and bias. Note that
        `fd.add_output` is _not_ called by this function.
    """
    assert not (
        (running_var is None) ^ (running_mean is None)
    ), "running_mean and running_var must both be given or both be None"

    # dyn_shape holds Scalars describing the size of the input x
    dyn_shape = fd.ops.tensor_sizes(x)

    num_dims = len(dyn_shape)

    batch_dim = 0
    batch_size = dyn_shape[batch_dim]

    channel_dim = num_dims - 1 if channels_last else 1
    num_channels = dyn_shape[channel_dim]

    # Running stats will be kept possibly for channel but never by instance,
    # so we will reduce along batch_dim before updating running stats.
    # These are used to broadcast in spatial dims
    is_spatial_dim = [True] * num_dims
    is_spatial_or_batch_dim = [True] * num_dims

    if NamedAxis.BATCH in stat_axes:
        is_spatial_dim[batch_dim] = False
    if NamedAxis.CHANNEL in stat_axes:
        is_spatial_dim[channel_dim] = False
        is_spatial_or_batch_dim[channel_dim] = False
    x_reduction_axes = [ax for ax, flag in enumerate(is_spatial_dim) if flag]
    num_features = fd.define_scalar(1)
    for ax in x_reduction_axes:
        num_features = fd.ops.mul(num_features, dyn_shape[ax])

    mean = fd.ops.broadcast(mean, is_spatial_dim)

    norm = fd.ops.reciprocal(num_features)
    grad_output_sum = fd.ops.sum(grad_output, x_reduction_axes)
    dot_p = fd.ops.sum(
        fd.ops.mul(
            grad_output,
            fd.ops.sub(x, mean),
        ),
        x_reduction_axes,
    )
    grad_mean = fd.ops.broadcast(fd.ops.mul(grad_output_sum, norm), is_spatial_dim)
    proj_scale = fd.ops.broadcast(
        fd.ops.mul(
            fd.ops.mul(dot_p, norm),
            fd.ops.mul(invstd, invstd),
        ),
        is_spatial_dim,
    )

    invstd_bcast = fd.ops.broadcast(invstd, is_spatial_dim)
    grad_scale = (
        invstd_bcast
        if weight is None
        else fd.ops.mul(
            invstd_bcast,
            fd.ops.broadcast(weight, is_spatial_or_batch_dim),
        )
    )
    if use_input_stats:
        proj = fd.ops.mul(fd.ops.sub(x, mean), proj_scale)
        grad_input = fd.ops.mul(
            fd.ops.sub(
                fd.ops.sub(grad_output, proj),
                grad_mean,
            ),
            grad_scale,
        )
    else:
        grad_input = fd.ops.mul(grad_output, grad_scale)

    if weight is not None:
        grad_weight = fd.ops.mul(dot_p, invstd)
        grad_weight_reduced = fd.ops.sum(grad_weight, [0])
    else:
        grad_weight_reduced = None
    if bias is not None:
        grad_bias = grad_output_sum
        grad_bias_reduced = fd.ops.sum(grad_bias, [0])
    else:
        grad_bias_reduced = None

    return grad_input, grad_weight_reduced, grad_bias_reduced

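# (Added note) Derivation sketch for the training-mode branch above: with
# x_hat = (x - mean) * invstd, the computed gradient is the standard
#
#     grad_input = (grad_output
#                   - mean_r(grad_output)                  # grad_mean
#                   - x_hat * mean_r(grad_output * x_hat)  # proj
#                  ) * invstd * weight                     # grad_scale
#
# where mean_r averages over x_reduction_axes: dot_p * norm is the covariance
# of grad_output with (x - mean), and scaling it by invstd**2 (proj_scale)
# yields the component along x_hat that must be removed.
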
class NormNVFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,  # the autograd context object used to stash state for backward
        x: torch.Tensor,
        weight: Optional[torch.Tensor],
        bias: Optional[torch.Tensor],
        running_mean: Optional[torch.Tensor],
        running_var: Optional[torch.Tensor],
        use_input_stats: bool,
        momentum: float,
        eps: float,
        unbiased: bool,
        stat_axes: List[NamedAxis],
    ) -> torch.Tensor:
        # When x.shape[1] == 1, is_contiguous will report the tensor as
        # channels_last even when it is ordinary contiguous. This causes some
        # issues, so we only detect channels_last when channels > 1.
        channels_last = x.shape[1] > 1 and (
            x.is_contiguous(memory_format=torch.channels_last)
            or x.is_contiguous(memory_format=torch.channels_last_3d)
        )
        xorig = x
        if channels_last:
            order = [0] + list(range(2, len(x.shape))) + [1]
            x = x.permute(order)

        x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype)

        with nvfuser.FusionDefinition() as fd:
            tv_x = partially_contig_tensor(fd, x)
            inputs = [x]
            if weight is not None:
                tv_weight = partially_contig_tensor(fd, weight)
                inputs.append(weight)
            else:
                tv_weight = None

            if bias is not None:
                tv_bias = partially_contig_tensor(fd, bias)
                inputs.append(bias)
            else:
                tv_bias = None

            if running_mean is None:
                tv_running_mean = None
                tv_running_var = None
            else:
                assert running_var is not None
                tv_running_mean = partially_contig_tensor(fd, running_mean)
                tv_running_var = partially_contig_tensor(fd, running_var)
                inputs.extend([running_mean, running_var])

            s_momentum = fd.define_scalar(nvfuser.DataType.Double)
            s_eps = fd.define_scalar(nvfuser.DataType.Double)
            inputs.extend([momentum, eps])

            # cast inputs if necessary
            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float)
            if weight is not None and weight.dtype in [torch.half, torch.bfloat16]:
                tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float)
            if bias is not None and bias.dtype in [torch.half, torch.bfloat16]:
                tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float)

            out, mean, invstd = norm_fusion_forward(
                fd,
                inputs,
                tv_x,
                tv_weight,
                tv_bias,
                tv_running_mean,
                tv_running_var,
                s_eps,
                use_input_stats,
                s_momentum,
                channels_last,
                x_datatype=x_datatype,
                unbiased=unbiased,
                stat_axes=stat_axes,
            )

            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                out = fd.ops.cast(out, x_datatype)

            fd.add_output(out)
            fd.add_output(mean)
            fd.add_output(invstd)

        out, mean, invstd = fd.execute(inputs)

        ctx.stat_axes = stat_axes
        ctx.use_input_stats = use_input_stats
        ctx.channels_last = channels_last
        # saving for backward in "explicit channels-last format"
        ctx.save_for_backward(x, weight, bias, running_mean, running_var, mean, invstd)
        if channels_last:
            order = [0, len(x.shape) - 1] + list(range(1, len(x.shape) - 1))
            out = out.permute(order)
            if len(out.shape) == 4:
                assert out.is_contiguous(memory_format=torch.channels_last)
                assert xorig.is_contiguous(memory_format=torch.channels_last)
            elif len(out.shape) == 5:
                assert out.is_contiguous(memory_format=torch.channels_last_3d)
                assert xorig.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                raise RuntimeError(
                    "unhandled channels_last format variation in forward"
                )
        return out

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        """Instance norm backward using NVFuser"""
        if ctx.channels_last:
            order = [0] + list(range(2, len(grad_output.shape))) + [1]
            grad_output = grad_output.permute(order)
        # input was saved in "explicit channels-last format"
        # assert ctx.saved_tensors[0].is_contiguous()
        # grad_output = grad_output.contiguous()
        x, weight, bias, running_mean, running_var, mean, invstd = ctx.saved_tensors

        with nvfuser.FusionDefinition() as fd:
            tv_x = partially_contig_tensor(fd, x)
            if x.dtype in [torch.half, torch.bfloat16]:
                tv_x = fd.ops.cast(tv_x, nvfuser.DataType.Float)
            inputs = [x]
            if weight is not None:
                tv_weight = partially_contig_tensor(fd, weight)
                if weight.dtype in [torch.half, torch.bfloat16]:
                    tv_weight = fd.ops.cast(tv_weight, nvfuser.DataType.Float)
                inputs.append(weight)
            else:
                tv_weight = None
            if bias is not None:
                tv_bias = partially_contig_tensor(fd, bias)
                if bias.dtype in [torch.half, torch.bfloat16]:
                    tv_bias = fd.ops.cast(tv_bias, nvfuser.DataType.Float)
                inputs.append(bias)
            else:
                tv_bias = None
            if running_mean is not None:
                tv_running_mean = partially_contig_tensor(fd, running_mean)
                if running_mean.dtype in [torch.half, torch.bfloat16]:
                    tv_running_mean = fd.ops.cast(
                        tv_running_mean, nvfuser.DataType.Float
                    )
                inputs.append(running_mean)
            else:
                tv_running_mean = None
            if running_var is not None:
                tv_running_var = partially_contig_tensor(fd, running_var)
                if running_var.dtype in [torch.half, torch.bfloat16]:
                    tv_running_var = fd.ops.cast(tv_running_var, nvfuser.DataType.Float)
                inputs.append(running_var)
            else:
                tv_running_var = None

            tv_mean = partially_contig_tensor(fd, mean)
            if mean.dtype in [torch.half, torch.bfloat16]:
                tv_mean = fd.ops.cast(tv_mean, nvfuser.DataType.Float)
            inputs.append(mean)
            tv_invstd = partially_contig_tensor(fd, invstd)
            if invstd.dtype in [torch.half, torch.bfloat16]:
                tv_invstd = fd.ops.cast(tv_invstd, nvfuser.DataType.Float)
            inputs.append(invstd)

            tv_grad_output = partially_contig_tensor(fd, grad_output)
            if grad_output.dtype in [torch.half, torch.bfloat16]:
                tv_grad_output = fd.ops.cast(tv_grad_output, nvfuser.DataType.Float)
            inputs.append(grad_output)

            x_datatype = torch_dtype_to_nvfuser_dtype(x.dtype)

            grad_input, grad_weight, grad_bias = norm_fusion_backward(
                fd,
                inputs,
                tv_x,
                tv_grad_output,
                tv_mean,
                tv_invstd,
                tv_weight,
                tv_bias,
                tv_running_mean,
                tv_running_var,
                ctx.use_input_stats,
                ctx.channels_last,
                x_datatype=x_datatype,
                stat_axes=ctx.stat_axes,
            )

            if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                grad_input = fd.ops.cast(grad_input, x_datatype)
            fd.add_output(grad_input)

            if weight is not None:
                if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                    grad_weight = fd.ops.cast(grad_weight, x_datatype)
                fd.add_output(grad_weight)

            if bias is not None:
                if x_datatype in [nvfuser.DataType.Half, nvfuser.DataType.BFloat16]:
                    grad_bias = fd.ops.cast(grad_bias, x_datatype)
                fd.add_output(grad_bias)

        res = fd.execute(inputs)
        grad_input = res[0]
        c = 1
        if weight is not None:
            grad_weight = res[c]
            c += 1
        else:
            grad_weight = None
        if bias is not None:
            grad_bias = res[c]
            c += 1
        else:
            grad_bias = None

        if ctx.channels_last:
            order = [0, len(grad_input.shape) - 1] + list(
                range(1, len(grad_input.shape) - 1)
            )
            grad_input = grad_input.permute(order)
            if len(grad_input.shape) == 4:
                assert grad_input.is_contiguous(memory_format=torch.channels_last)
            elif len(grad_input.shape) == 5:
                assert grad_input.is_contiguous(memory_format=torch.channels_last_3d)
            else:
                raise RuntimeError(
                    "unhandled channels_last format variation in backward"
                )
        return (
            grad_input,
            grad_weight,
            grad_bias,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

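# (Added note) NormNVFuserFunction is normally reached through the module
# wrappers below, but it can be exercised directly. A minimal sketch with
# hypothetical inputs (instance norm, batch statistics only, no affine
# parameters or running stats):
#
#     y = NormNVFuserFunction.apply(
#         x, None, None, None, None,  # x, weight, bias, running_mean/_var
#         True, 0.1, 1e-5, False,     # use_input_stats, momentum, eps, unbiased
#         [NamedAxis.BATCH, NamedAxis.CHANNEL],
#     )
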
class _NormNVFuserBase(torch.nn.modules.batchnorm._NormBase):
    stat_axes: Optional[List[NamedAxis]] = None

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def _check_input_dim(self, input: torch.Tensor) -> None:
        raise NotImplementedError

    def _load_from_state_dict(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Any,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        version = local_metadata.get("version", None)
        # at version 1: removed running_mean and running_var when
        # track_running_stats=False (default)
        if version is None and not self.track_running_stats:
            running_stats_keys = []
            for name in ("running_mean", "running_var"):
                key = prefix + name
                if key in state_dict:
                    running_stats_keys.append(key)
            if len(running_stats_keys) > 0:
                error_msgs.append(
                    "Unexpected running stats buffer(s) {names} for {klass} "
                    "with track_running_stats=False. If state_dict is a "
                    "checkpoint saved before 0.4.0, this may be expected "
                    "because {klass} does not track running stats by default "
                    "since 0.4.0. Please remove these keys from state_dict. If "
                    "the running stats are actually needed, instead set "
                    "track_running_stats=True in {klass} to enable them. See "
                    "the documentation of {klass} for details.".format(
                        names=" and ".join(
                            '"{}"'.format(k) for k in running_stats_keys
                        ),
                        klass=self.__class__.__name__,
                    )
                )
                for key in running_stats_keys:
                    state_dict.pop(key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.is_cuda, "NVFuser InstanceNorm is CUDA only"
        self._check_input_dim(input)
        out = NormNVFuserFunction.apply(
            input,
            self.weight,
            self.bias,
            self.running_mean,
            self.running_var,
            self.training or not self.track_running_stats,
            self.momentum,
            self.eps,
            False,  # unbiased=False to match PyTorch functionality
            self.stat_axes,
        )
        return out


class _InstanceNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.BATCH, NamedAxis.CHANNEL]


class _BatchNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.CHANNEL]


class _LayerNormNVFuser(_NormNVFuserBase):
    stat_axes = [NamedAxis.BATCH]


class InstanceNorm1dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 3:
            raise ValueError("expected 3D input (got {}D input)".format(input.dim()))


class InstanceNorm2dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 4:
            raise ValueError("expected 4D input (got {}D input)".format(input.dim()))


class InstanceNorm3dNVFuser(_InstanceNormNVFuser):
    def _check_input_dim(self, input: torch.Tensor) -> None:
        if input.dim() != 5:
            raise ValueError("expected 5D input (got {}D input)".format(input.dim()))
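
The wrapper classes are drop-in replacements for torch.nn.InstanceNorm1d/2d/3d. A minimal usage sketch, assuming the contrib.nn package re-exports these classes (as its __init__.py in the listing suggests); the shapes and hyperparameters here are illustrative, not from the package:

    import torch
    from nvfuser.contrib.nn import InstanceNorm3dNVFuser

    norm = InstanceNorm3dNVFuser(num_features=4, affine=True).cuda().half()
    x = torch.randn(2, 4, 8, 8, 8, device="cuda", dtype=torch.half, requires_grad=True)
    y = norm(x)          # forward builds and executes an nvfuser FusionDefinition
    y.sum().backward()   # backward runs a second fused definition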