mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/mint/nn/layer/normalization.py
ADDED
@@ -0,0 +1,477 @@
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""normalization for mint"""
+from __future__ import absolute_import
+from __future__ import division
+
+from typing import Optional
+import numpy as np
+import mindspore as ms
+from mindspore import mint
+from mindspore import ops
+from mindspore import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore import _checkparam as validator
+from mindspore.common import dtype as mstype
+from mindspore.nn.cell import Cell
+from mindspore.nn.layer.normalization import LayerNormExt as LayerNorm
+from mindspore.ops import group_norm
+
+
+class _NormBase(Cell):
+    """Common base of _InstanceNorm and _BatchNorm"""
+
+    def __init__(self,
+                 num_features: int,
+                 eps: float = 1e-5,
+                 momentum: float = 0.1,
+                 affine: bool = True,
+                 track_running_stats: bool = True,
+                 dtype=None
+                 ) -> None:
+        super(_NormBase, self).__init__()
+        self.shape = ops.Shape()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.affine = affine
+        self.track_running_stats = track_running_stats
+        self.dtype = dtype if dtype is not None else mstype.float32
+        if self.affine:
+            self.weight = Parameter(
+                Tensor(np.empty(num_features), dtype=self.dtype), name="weight")
+            self.bias = Parameter(
+                Tensor(np.empty(num_features), dtype=self.dtype), name="bias")
+            self.weight: Optional[Parameter]
+            self.bias: Optional[Parameter]
+        else:
+            self.weight = None
+            self.bias = None
+        if self.track_running_stats:
+            self.running_mean = Parameter(Tensor(np.zeros(num_features), dtype=self.dtype),
+                                          requires_grad=False, name="running_mean")
+            self.running_var = Parameter(Tensor(np.ones(num_features), dtype=self.dtype),
+                                         requires_grad=False, name="running_var")
+            self.running_mean: Optional[Tensor]
+            self.running_var: Optional[Tensor]
+            self.num_batches_tracked = Parameter(Tensor(0, dtype=ms.float32),
+                                                 requires_grad=False, name="num_batches_tracked")
+            self.num_batches_tracked: Optional[Tensor]
+        else:
+            self.running_mean = None
+            self.running_var = None
+            self.num_batches_tracked = None
+        self.reset_parameters()
+
+    def reset_running_stats(self) -> None:
+        """init parameters"""
+
+        if self.track_running_stats:
+            zero_running_mean = Tensor(
+                np.zeros(self.num_features), dtype=self.dtype)
+            one_running_var = Tensor(
+                np.ones(self.num_features), dtype=self.dtype)
+            zero_num_batches_tracked = Tensor(0, dtype=ms.float32)
+
+            ops.assign(self.running_mean, zero_running_mean)
+            ops.assign(self.running_var, one_running_var)
+            ops.assign(self.num_batches_tracked, zero_num_batches_tracked)
+
+    def reset_parameters(self) -> None:
+        self.reset_running_stats()
+        if self.affine:
+            one_weight = Tensor(np.ones(self.num_features), dtype=self.dtype)
+            zero_bias = Tensor(np.zeros(self.num_features), dtype=self.dtype)
+
+            ops.assign(self.weight, one_weight)
+            ops.assign(self.bias, zero_bias)
+
+    def _check_input_dim(self, input):
+        raise NotImplementedError
+
+    def extend_repr(self):
+        return 'num_features={}, eps={}, momentum={}, affine={}, track_running_stats={}'.format(
+            self.num_features, self.eps, self.momentum, self.affine, self.track_running_stats)
+
+
+class _BatchNorm(_NormBase):
+    """common base of BatchNormXxx"""
+
+    def __init__(
+            self,
+            num_features: int,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            dtype=None) -> None:
+        super(_BatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats,
+                                         dtype)
+        self.training = True
+
+    def _check_input_dim(self, input):
+        raise NotImplementedError
+
+    def construct(self, input):
+        self._check_input_dim(input)
+
+        if self.momentum is None:
+            exponential_average_factor = 0.0
+        else:
+            exponential_average_factor = self.momentum
+
+        if self.training and self.track_running_stats:
+            if self.num_batches_tracked is not None:
+                num_batches_tracked_one = Tensor(1, dtype=ms.float32)
+                ops.assign_add(self.num_batches_tracked,
+                               num_batches_tracked_one)
+                if self.momentum is None:
+                    exponential_average_factor = float(1.0 / self.num_batches_tracked)
+                else:
+                    exponential_average_factor = self.momentum
+
+        if self.training:
+            bn_training = True
+        else:
+            bn_training = (self.running_mean is None) and (
+                self.running_var is None)
+
+        return mint.functional.batch_norm(
+            input,
+            self.running_mean
+            if not self.training or self.track_running_stats
+            else None,
+            self.running_var if not self.training or self.track_running_stats
+            else None,
+            self.weight,
+            self.bias,
+            bn_training,
+            exponential_average_factor,
+            self.eps,
+        )
+
+
+class BatchNorm1d(_BatchNorm):
+    r"""
+    Applies Batch Normalization over a 2D or 3D input as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ .
+
+    .. math::
+
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the number of features or channels of the input). By default, the
+    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
+
+    .. warning::
+        This API does not support Dynamic Rank.
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        num_features (int): `C` from an expected input of shape :math:`(N, C, L)`.
+        eps (float, optional): a value added to the denominator for numerical stability.
+            Default: ``1e-5`` .
+        momentum (float, optional): the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average. Default: ``0.1`` .
+        affine (bool, optional): a boolean value that when set to ``True``, this cell has
+            learnable affine parameters. Default: ``True`` .
+        track_running_stats (bool, optional): a boolean value that when set to ``True``, this
+            cell tracks the running mean and variance, and when set to ``False``,
+            this cell does not track such statistics. And this cell always uses batch statistics
+            in both training and eval modes. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C)` or :math:`(N, C, L)`,
+          where :math:`N` means batch, :math:`C` means the number of feature or the number of channel,
+          and :math:`L` is the length of sequence.
+
+    Outputs:
+        Tensor, has the same type and shape as `input`.
+
+    Raises:
+        TypeError: If `num_features` is not a int number.
+        TypeError: If `eps` is not a float.
+        ValueError: If `num_features` is less than 1.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> input_x = mindspore.Tensor([[0.7, 0.5, 0.5, 0.6], [0.5, 0.4, 0.6, 0.9]])
+        >>> net = mint.nn.BatchNorm1d(4)
+        >>> output = net(input_x)
+        >>> print(output)
+        [[ 0.99950075 0.9980011 -0.9980068 -0.9997783]
+         [-0.9995012 -0.99799967 0.9980068 0.9997778]]
+    """
+
+    def _check_input_dim(self, input):
+        shape = self.shape(input)
+        dim = len(shape)
+        if dim != 2 and dim != 3:
+            raise ValueError(
+                "expected 2D or 3D input (got {}D input)".format(dim)
+            )
+
+
+class BatchNorm2d(_BatchNorm):
+    r"""
+    Applies Batch Normalization over a 4D input as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ .
+
+    .. math::
+
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the number of features or channels of the input). By default, the
+    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
+
+    .. warning::
+        This API does not support Dynamic Rank.
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        num_features (int): `C` from an expected input of shape :math:`(N, C, H, W)`.
+        eps (float, optional): a value added to the denominator for numerical stability.
+            Default: ``1e-5`` .
+        momentum (float, optional): the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average. Default: ``0.1`` .
+        affine (bool, optional): a boolean value that when set to ``True``, this cell has
+            learnable affine parameters. Default: ``True`` .
+        track_running_stats (bool, optional): a boolean value that when set to ``True``, this
+            cell tracks the running mean and variance, and when set to ``False``,
+            this cell does not track such statistics. And this cell always uses batch statistics
+            in both training and eval modes. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C, H, W)`.
+
+    Outputs:
+        Tensor, has the same type and shape as `input`.
+
+    Raises:
+        TypeError: If `num_features` is not a int number.
+        TypeError: If `eps` is not a float.
+        ValueError: If `num_features` is less than 1.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> input_x = mindspore.Tensor([0.3, 0.4, 0.5, 0.3])
+        >>> input_x = input_x.reshape((2, 2, 1, 1))
+        >>> net = mint.nn.BatchNorm2d(2)
+        >>> output = net(input_x)
+        >>> print(output)
+        [[[[-0.99950075]]
+          [[0.9980087]]]
+         [[[0.999501]]
+          [[-0.9980097]]]]
+    """
+
+    def _check_input_dim(self, input):
+        shape = self.shape(input)
+        dim = len(shape)
+        if dim != 4:
+            raise ValueError(
+                "expected 4D input (got {}D input)".format(dim)
+            )
+
+
+class BatchNorm3d(_BatchNorm):
+    r"""
+    Applies Batch Normalization over a 5D input as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing
+    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_ .
+
+    .. math::
+
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
+    of size `C` (where `C` is the number of features or channels of the input). By default, the
+    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
+
+    .. warning::
+        This API does not support Dynamic Rank.
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        num_features (int): `C` from an expected input of shape :math:`(N, C, D, H, W)`.
+        eps (float, optional): a value added to the denominator for numerical stability.
+            Default: ``1e-5`` .
+        momentum (float, optional): the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average. Default: ``0.1`` .
+        affine (bool, optional): a boolean value that when set to ``True``, this cell has
+            learnable affine parameters. Default: ``True`` .
+        track_running_stats (bool, optional): a boolean value that when set to ``True``, this
+            cell tracks the running mean and variance, and when set to ``False``,
+            this cell does not track such statistics. And this cell always uses batch statistics
+            in both training and eval modes. Default: ``True`` .
+        dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C, D, H, W)`.
+
+    Outputs:
+        Tensor, has the same type and shape as `input`.
+
+    Raises:
+        TypeError: If `num_features` is not a int number.
+        TypeError: If `eps` is not a float.
+        ValueError: If `num_features` is less than 1.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> input_x = mindspore.Tensor([0.1, 0.9, 1.2, 2.3])
+        >>> input_x = input_x.reshape((1, 2, 1, 1, 2))
+        >>> net = mint.nn.BatchNorm3d(2)
+        >>> output = net(input_x)
+        >>> print(output)
+        [[[[[-0.9999688  0.99996865]]]
+          [[[-0.9999833  0.9999831]]]]]
+    """
+
+    def _check_input_dim(self, input):
+        shape = self.shape(input)
+        dim = len(shape)
+        if dim != 5:
+            raise ValueError(
+                "expected 5D input (got {}D input)".format(dim)
+            )
+
+
+class GroupNorm(Cell):
+    r"""
+    Group Normalization over a mini-batch of inputs.
+
+    Group Normalization is widely used in recurrent neural networks. It applies
+    normalization on a mini-batch of inputs for each single training case as described
+    in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_.
+
+    Group Normalization divides the channels into groups and computes within each group
+    the mean and variance for normalization, and it performs very stable over a wide
+    range of batch size. :math:`\gamma` and :math:`\beta` are trainable scale and shift.
+    It can be described using the following formula:
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    where :math:`\gamma` is `weight`, :math:`\beta` is `bias`, and :math:`\epsilon` is `eps`.
+
+    Args:
+        num_groups (int): The number of groups to be divided along the channel dimension.
+        num_channels (int): The number of input channels.
+        eps (float, optional): A value added to the denominator for numerical stability. Default: ``1e-05`` .
+        affine (bool, optional): The parameters, such as :math:`\gamma` and :math:`\beta`, are learnable
+            when set to ``true`` . Default: ``True`` .
+        dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
+
+    Inputs:
+        - **input** (Tensor) - The input feature with shape :math:`(N, C, *)`, where :math:`*` means, any number of
+          additional dimensions.
+
+    Outputs:
+        Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.
+
+    Raises:
+        TypeError: If `num_groups` or `num_channels` is not an int.
+        TypeError: If `eps` is not a float.
+        TypeError: If `affine` is not a bool.
+        ValueError: If `num_groups` or `num_channels` is less than 1.
+        ValueError: If `num_channels` is not divided by `num_groups`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> group_norm_op = ms.mint.nn.GroupNorm(2, 2)
+        >>> x = ms.Tensor(np.ones([1, 2, 4, 4], np.float32))
+        >>> output = group_norm_op(x)
+        >>> print(output)
+        [[[[0. 0. 0. 0.]
+           [0. 0. 0. 0.]
+           [0. 0. 0. 0.]
+           [0. 0. 0. 0.]]
+          [[0. 0. 0. 0.]
+           [0. 0. 0. 0.]
+           [0. 0. 0. 0.]
+           [0. 0. 0. 0.]]]]
+    """
+
+    def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, dtype=None):
+        """Initialize GroupNorm."""
+        super(GroupNorm, self).__init__()
+        ms_dtype = mstype.float32 if dtype is None else dtype
+        gamma_init = 'ones'
+        beta_init = 'zeros'
+
+        self.num_groups = validator.check_positive_int(
+            num_groups, "num_groups", self.cls_name)
+        self.num_channels = validator.check_positive_int(
+            num_channels, "num_channels", self.cls_name)
+        if num_channels % num_groups != 0:
+            raise ValueError(f"For '{self.cls_name}', the 'num_channels' must be divided by 'num_groups', "
+                             f"but got 'num_channels': {num_channels}, 'num_groups': {num_groups}.")
+        self.eps = validator.check_value_type(
+            'eps', eps, (float,), type(self).__name__)
+        self.affine = validator.check_bool(
+            affine, arg_name="affine", prim_name=self.cls_name)
+
+        self.gamma = Parameter(initializer(
+            gamma_init, self.num_channels, dtype=ms_dtype), name="gamma", requires_grad=affine)
+        self.beta = Parameter(initializer(
+            beta_init, self.num_channels, dtype=ms_dtype), name="beta", requires_grad=affine)
+
+    def _cal_output(self, x):
+        """calculate groupnorm output"""
+        return group_norm(x, self.num_groups, self.gamma, self.beta, self.eps)
+
+    def extend_repr(self):
+        return 'num_groups={}, num_channels={}, eps={}, affine={}'.format(
+            self.num_groups, self.num_channels, self.eps, self.affine)
+
+    def construct(self, input):
+        output = self._cal_output(input)
+        return output
+
+
+__all__ = [
+    'GroupNorm',
+    'BatchNorm1d',
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'LayerNorm',
+]
mindspore/mint/nn/layer/pooling.py
ADDED
@@ -0,0 +1,110 @@
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""normalization for mint"""
+from __future__ import absolute_import
+from __future__ import division
+
+from mindspore import mint
+from mindspore.nn.cell import Cell
+
+
+class _AdaptiveAvgPoolNd(Cell):
+    """Common base of AdaptiveAvgPoolNd"""
+
+    def __init__(self, output_size) -> None:
+        super(_AdaptiveAvgPoolNd, self).__init__()
+        self.output_size = output_size
+
+    def extend_repr(self):
+        return 'output_size={}'.format(self.output_size)
+
+
+class AdaptiveAvgPool1d(_AdaptiveAvgPoolNd):
+    r"""
+    Applies a 1D adaptive average pooling over an input signal composed of several input planes.
+
+    The output is of size :math:`L_{out}` , for any input size.
+    The number of output features is equal to the number of input planes.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        output_size (int): the target output size :math:`L_{out}` .
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C, L_{in})` or :math:`(C, L_{in})` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[[2, 1, 2], [2, 3, 5]]]), mindspore.float16)
+        >>> net = mint.nn.AdaptiveAvgPool1d(3)
+        >>> output = net(input)
+        >>> print(output)
+        [[[2. 1. 2.]
+          [2. 3. 5.]]]
+    """
+
+    def construct(self, input):
+        return mint.nn.functional.adaptive_avg_pool1d(input, self.output_size)
+
+
+class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):
+    r"""
+    Applies a 2D adaptive average pooling over an input signal composed of several input planes.
+
+    The output is of size :math:`H x W` , for any input size.
+    The number of output features is equal to the number of input planes.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        output_size (Union(int, tuple[int])): the target output size of the image of the form :math:`H x W` .
+            Can be a tuple :math:`(H, W)` or a single :math:`H` for square image :math:`H x H` .
+            :math:`H` and :math:`W` can be either a ``int`` , or ``None`` which means the size will
+            be the same as that of the input.
+
+    Inputs:
+        - **input** (Tensor) - The input with shape :math:`(N, C, H, W)` or :math:`(C, H, W)` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[[2, 1, 2], [2, 3, 5]]]), mindspore.float16)
+        >>> net = mint.nn.AdaptiveAvgPool2d((2, 2))
+        >>> output = net(input)
+        >>> print(output)
+        [[[1.5 1.5]
+          [2.5 4. ]]]
+    """
+
+    def construct(self, input):
+        return mint.nn.functional.adaptive_avg_pool2d(input, self.output_size)
+
+
+__all__ = [
+    'AdaptiveAvgPool2d',
+    'AdaptiveAvgPool1d',
+]
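
Note (not part of the diff): the ``None`` form of `output_size` described in the AdaptiveAvgPool2d Args above keeps the corresponding input dimension unchanged. A minimal sketch of that usage follows; the printed shape is inferred from the documented semantics, not from a verified run:

    >>> import numpy as np
    >>> import mindspore
    >>> from mindspore import Tensor, mint
    >>> x = Tensor(np.ones((1, 3, 6, 4)), mindspore.float32)
    >>> pool = mint.nn.AdaptiveAvgPool2d((None, 2))   # keep H as-is, average W down to 2
    >>> y = pool(x)
    >>> print(y.shape)
    (1, 3, 6, 2)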
mindspore/mint/optim/adamw.py
CHANGED
@@ -23,18 +23,24 @@ from mindspore.ops import auto_generate as gen
 from mindspore.experimental.optim.optimizer import Optimizer
 from mindspore import _checkparam as validator
 
-_optim_adamw_opt = C.MultitypeFuncGraph("optim_adamw_opt")
 hyper_map = C.HyperMap()
 
 
-
-
+def _run_optim_adamw_amsgrad_opt(opt, beta1, beta2, lr, eps, weight_decay, step, amsgrad, maximize, parameters, grads,
+                                 exp_avg, exp_avg_sq, max_exp_avg_sq):
+    """Apply adamw optimizer to the weight parameter."""
+    success = True
+    opt(parameters, exp_avg, exp_avg_sq, max_exp_avg_sq, grads, step, lr, beta1, beta2, weight_decay, eps, amsgrad,
+        maximize)
+    return success
+
+
 def _run_optim_adamw_opt(opt, beta1, beta2, lr, eps, weight_decay, step, amsgrad, maximize, parameters, grads, exp_avg,
-                         exp_avg_sq
+                         exp_avg_sq):
     """Apply adamw optimizer to the weight parameter."""
     success = True
-    opt(parameters, exp_avg, exp_avg_sq,
-
+    opt(parameters, exp_avg, exp_avg_sq, exp_avg_sq, grads, step, lr, beta1, beta2, weight_decay, eps, amsgrad,
+        maximize)
     return success
 
 
@@ -156,15 +162,14 @@ class AdamW(Optimizer):
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad,
                         maximize=maximize)
+        self.max_v_group = True
         super(AdamW, self).__init__(params, defaults)
 
         self.exp_avg = self.parameters.clone(prefix="exp_avg", init='zeros')
         self.exp_avg_sq = self.parameters.clone(prefix="exp_avg_sq", init='zeros')
-        self.max_exp_avg_sq = self.parameters.clone(prefix="max_exp_avg_sq", init='zeros')
         self.state_step = Parameter(Tensor([-1], mstype.float32), "state_step")
         self.increase_tensor = Tensor(1, mstype.float32)
         self.assignadd = P.AssignAdd()
-        self.op_cast = P.Cast()
         self.adamw_opt = gen.AdamW()
 
     def construct(self, gradients):
@@ -177,9 +182,17 @@ class AdamW(Optimizer):
             lr = group.get("lr")
             grads = tuple(gradients[start_id: end_id])
 
-
-
-
-
-
+            if group.get("amsgrad"):
+                self.hyper_map(F.partial(_run_optim_adamw_amsgrad_opt, self.adamw_opt, beta1, beta2, float(lr),
+                                         group.get("eps"), group.get("weight_decay"), self.state_step,
+                                         group.get("amsgrad"), maximize),
+                               self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                               self.exp_avg_sq[start_id: end_id], group.get("max_exp_avg_sq"))
+            else:
+                self.hyper_map(F.partial(_run_optim_adamw_opt, self.adamw_opt, beta1, beta2, float(lr),
+                                         group.get("eps"), group.get("weight_decay"), self.state_step,
+                                         group.get("amsgrad"), maximize),
+                               self.parameters[start_id: end_id], grads, self.exp_avg[start_id: end_id],
+                               self.exp_avg_sq[start_id: end_id])
+
         return True
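
Note (not part of the diff): a minimal usage sketch of the reworked `mint.optim.AdamW`, using only the constructor arguments visible in the `defaults` dict above. With `amsgrad=True` the update is assumed to go through the new `_run_optim_adamw_amsgrad_opt` branch and to read the per-group `max_exp_avg_sq` state, which is presumed to be created by the base `Optimizer` now that the global `self.max_exp_avg_sq` clone has been removed:

    >>> import mindspore as ms
    >>> from mindspore import nn, mint, ops
    >>> net = nn.Dense(4, 2)
    >>> optimizer = mint.optim.AdamW(net.trainable_params(), lr=1e-3,
    ...                              betas=(0.9, 0.999), weight_decay=0.01, amsgrad=True)
    >>> grads = tuple(ops.zeros_like(p) for p in net.trainable_params())   # placeholder gradients
    >>> ok = optimizer(grads)   # one update step with the gradients of the current batch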