mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/mint/distributed/__init__.py ADDED

@@ -0,0 +1,31 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Collective communication interface.
+
+Note that the APIs in the following list need to preset communication environment variables.
+
+For Ascend devices, it is recommended to use the msrun startup method
+without any third-party or configuration file dependencies.
+Please see the `msrun start up
+<https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+for more details.
+"""
+from mindspore.mint.distributed.distributed import init_process_group, destroy_process_group, get_rank, get_world_size
+
+
+__all__ = [
+    "init_process_group", "destroy_process_group", "get_rank", "get_world_size"
+]
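This new package re-exports the four process-group functions defined in distributed.py below. As a quick orientation, a minimal sketch of how a launcher script might use the new import path; the msrun command line in the comment follows the launcher documentation linked in the module docstring, and the worker count and script name are placeholder assumptions:

# Hypothetical launch (assumption): msrun --worker_num=8 --local_worker_num=8 train.py
# Each spawned worker then runs this script.
import mindspore as ms
from mindspore.mint.distributed import (
    init_process_group, destroy_process_group, get_rank, get_world_size,
)

ms.set_context(device_target="Ascend")  # the API is Ascend-only per the docstrings
init_process_group()                    # backend defaults to "hccl"
print(f"worker {get_rank()} of {get_world_size()}")
destroy_process_group()                 # group=None destroys all groups and releases the lib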
mindspore/mint/distributed/distributed.py ADDED

@@ -0,0 +1,254 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Communication management API"""
+from mindspore import log as logger
+from mindspore.communication._comm_helper import _destroy_group_helper, GlobalComm, _get_rank_helper, _get_size_helper
+from mindspore.communication import init, release, get_group_size
+
+
+def init_process_group(backend="hccl",
+                       init_method=None,
+                       timeout=None,
+                       world_size=-1,
+                       rank=-1,
+                       store=None,
+                       pg_options=None,
+                       device_id=None):
+    """
+    Initialize the collective communication library and create the default collective communication group.
+
+    Note:
+        This method isn't supported in GPU and CPU versions of MindSpore.
+        On Ascend hardware platforms, this API should be called before the definition of any Tensor and Parameter,
+        and before the instantiation and execution of any operation or net.
+
+    Args:
+        backend (str, optional): The backend to use. Default is ``"hccl"``; currently only ``"hccl"`` is supported.
+        init_method (str, invalid): URL specifying how to initialize the collective communication group. Provided
+            for API consistency with PyTorch, but not currently supported; any setting is ignored.
+        timeout (timedelta, invalid): Timeout for executed APIs. Provided for API consistency with PyTorch, but not
+            currently supported; any setting is ignored.
+        world_size (int, optional): Number of processes participating in the job.
+        rank (int, invalid): Rank of the current process. Provided for API consistency with PyTorch, but not
+            currently supported; any setting is ignored.
+        store (Store, invalid): Key/value store accessible to all workers, used to exchange connection/address
+            information. Provided for API consistency with PyTorch, but not currently supported;
+            any setting is ignored.
+        pg_options (ProcessGroupOptions, invalid): Process group options specifying what additional options need to
+            be passed in during the construction of specific process groups. Provided
+            for API consistency with PyTorch, but not currently supported;
+            any setting is ignored.
+        device_id (int, invalid): The device ID to execute on. Provided for API consistency with PyTorch, but not
+            currently supported; any setting is ignored.
+
+    Raises:
+        ValueError: If `backend` is not hccl.
+        ValueError: If `world_size` is neither -1 nor the actual process group size.
+        RuntimeError: If the device target is invalid, the backend is invalid, distributed initialization fails,
+            or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
+            have not been exported when the backend is HCCL.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+            for more details.
+
+        >>> import mindspore as ms
+        >>> from mindspore import set_context
+        >>> from mindspore.mint.distributed import init_process_group, destroy_process_group
+        >>> set_context(device_target="Ascend")
+        >>> init_process_group()
+        >>> destroy_process_group()
+    """
+    if init_method is not None:
+        logger.warning("init_method is ignored, setting is invalid")
+    if timeout is not None:
+        logger.warning("timeout is ignored, setting is invalid")
+    if store is not None:
+        logger.warning("store is ignored, setting is invalid")
+    if pg_options is not None:
+        logger.warning("pg_options is ignored, setting is invalid")
+    if device_id is not None:
+        logger.warning("device_id is ignored, setting is invalid")
+    if rank != -1:
+        logger.warning("rank is ignored, setting is invalid")
+    if backend != "hccl":
+        raise ValueError("Only hccl is supported now, please set backend to hccl or use the default value")
+
+    # init hccl & create world group
+    init(backend)
+
+    if world_size != -1 and world_size != get_group_size():
+        raise ValueError("world_size is wrong, please use the default value or set it to: ", get_group_size())
+
+
+def destroy_process_group(group=None):
+    """
+    Destroy the user collective communication group.
+    If `group` is None or ``"hccl_world_group"``, destroy all groups and release the collective communication lib.
+
+    Note:
+        This method isn't supported in GPU and CPU versions of MindSpore.
+        This method should be used after init_process_group().
+
+    Args:
+        group (str): The communication group to destroy; the group should be created by init_process_group or
+            new_group.
+
+    Raises:
+        TypeError: If group is not a string.
+        RuntimeError: If HCCL is not available or MindSpore is a GPU/CPU version.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+            for more details.
+
+        >>> import mindspore as ms
+        >>> from mindspore import set_context
+        >>> from mindspore.mint.distributed import init_process_group, destroy_process_group
+        >>> set_context(device_target="Ascend")
+        >>> init_process_group()
+        >>> destroy_process_group()
+    """
+
+    if group == GlobalComm.WORLD_COMM_GROUP or group is None:
+        release()
+    elif not isinstance(group, str):
+        raise TypeError("For 'destroy_group', the argument 'group' must be type of string or None, "
+                        "but got 'group' type : {}.".format(type(group)))
+    else:
+        _destroy_group_helper(group)
+
+
+def get_rank(group=None):
+    """
+    Get the rank ID for the current device in the specified collective communication group.
+
+    Note:
+        This method should be used after init_process_group().
+
+    Args:
+        group (str): The communication group to work on. Normally, the group should be created by create_group;
+            otherwise, the default group is used. If None, ``GlobalComm.WORLD_COMM_GROUP`` will be used.
+
+    Returns:
+        int, the rank ID of the calling process within the group.
+        Returns -1 if the calling process is not part of the group.
+
+    Raises:
+        TypeError: If group is not a string.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+            for more details.
+
+        >>> from mindspore import set_context
+        >>> from mindspore.mint.distributed import init_process_group, get_rank
+        >>> set_context(device_target="Ascend")
+        >>> init_process_group()
+        >>> rank_id = get_rank()
+        >>> print(rank_id)
+        >>> # the result is the rank_id in world_group
+    """
+    if group is None:
+        group = GlobalComm.WORLD_COMM_GROUP
+    if not isinstance(group, str):
+        raise TypeError("For 'get_rank', the argument 'group' must be type of string, "
+                        "but got 'group' type : {}.".format(type(group)))
+    try:
+        ret = _get_rank_helper(group)
+    except RuntimeError as e:
+        logger.warning(e)
+        ret = -1
+    return ret
+
+
+def get_world_size(group=None):
+    """
+    Get the rank size of the specified collective communication group.
+
+    Note:
+        This method should be used after init_process_group().
+
+    Args:
+        group (str): The communication group to work on. Normally, the group should be created by create_group;
+            otherwise, the default group is used. If None, ``GlobalComm.WORLD_COMM_GROUP`` will be used.
+
+    Returns:
+        int, the rank size of the group.
+        Returns -1 if the group is not available.
+
+    Raises:
+        TypeError: If group is not a string.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        .. note::
+            Before running the following examples, you need to configure the communication environment variables.
+
+            For Ascend devices, it is recommended to use the msrun startup method
+            without any third-party or configuration file dependencies.
+            Please see the `msrun start up
+            <https://www.mindspore.cn/docs/zh-CN/master/model_train/parallel/msrun_launcher.html>`_
+            for more details.
+
+        >>> import mindspore as ms
+        >>> from mindspore import set_context
+        >>> from mindspore.mint.distributed import init_process_group, get_world_size
+        >>> set_context(device_target="Ascend")
+        >>> init_process_group()
+        >>> group_size = get_world_size()
+        >>> print("group_size is: ", group_size)
+        group_size is: 8
+    """
+    ret = -1
+    if group is None:
+        group = GlobalComm.WORLD_COMM_GROUP
+    if not isinstance(group, str):
+        raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
+                        "but got 'group' type : {}.".format(type(group)))
+    try:
+        ret = _get_size_helper(group)
+    except RuntimeError as e:
+        logger.warning(e)
+        ret = -1
+    return ret
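Two behaviors in the implementation above are worth noting: the PyTorch-style arguments (init_method, timeout, store, pg_options, device_id, rank) are accepted for signature compatibility but only trigger a warning, and get_rank/get_world_size swallow the helper's RuntimeError and return -1 instead of raising. A minimal sketch of what a caller can therefore rely on, assuming an Ascend/HCCL environment launched via msrun:

from mindspore.mint.distributed import init_process_group, get_rank, get_world_size

# Non-default values for the PyTorch-compatibility arguments only log a
# warning ("... is ignored, setting is invalid"); they do not raise:
init_process_group(backend="hccl", device_id=0)
# init_process_group(backend="nccl")  # ValueError: only hccl is supported

# get_rank()/get_world_size() catch the helper RuntimeError and return the
# sentinel -1 when the group is unavailable, so check before dividing work:
rank, world = get_rank(), get_world_size()
if rank == -1 or world == -1:
    raise SystemExit("process group not available")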
mindspore/mint/nn/__init__.py CHANGED

@@ -18,16 +18,17 @@ Neural Networks Cells.
 Predefined building blocks or computing units to construct neural networks.
 """
 from __future__ import absolute_import
+import mindspore.ops as ops
+from mindspore.mint.nn import functional as F
 from mindspore.nn.cell import Cell
-from mindspore.nn
-
-from mindspore.nn.extend import MaxPool2d
+from mindspore.nn import EmbeddingExt as Embedding, MaxPool2dExt as MaxPool2d, LayerNormExt as LayerNorm, Linear
+
 # 1

 # 2

 # 3
-
+from mindspore.nn.layer.basic import Identity
 # 4

 # 5

@@ -37,13 +38,13 @@ from mindspore.nn.layer.basic import UnfoldExt as Unfold
 # 7
 from mindspore.nn.layer.basic import Fold
 # 8
-from mindspore.nn.
-from mindspore.nn.extend.layer.normalization import *
+from mindspore.nn.layer.activation import SoftmaxExt as Softmax
 # 9
 from mindspore.nn.layer.basic import UpsampleExt as Upsample
 # 10

 # 11
+from mindspore.nn.layer import ReLU

 # 12

@@ -54,11 +55,11 @@ from mindspore.nn.layer.basic import DropoutExt as Dropout
 # 15

 # 16
-
+from mindspore.nn.layer import LogSoftmaxExt as LogSoftmax
 # 17

 # 18
-
+from mindspore.nn.layer import PReLUExt as PReLU
 # 19

 # 20

@@ -98,11 +99,12 @@ from mindspore.nn.layer.basic import DropoutExt as Dropout
 # 37

 # 38
-
+
 # 39

 # 40
-
+from mindspore.mint.nn.layer.normalization import GroupNorm
+from mindspore.mint.nn.layer.normalization import LayerNorm
 # 41

 # 42

@@ -114,6 +116,7 @@ from mindspore.nn.extend.basic import Linear
 # 45

 # 46
+from mindspore.mint.nn.layer.activation import SiLU, LogSigmoid

 # 47

@@ -220,9 +223,38 @@ from mindspore.nn.extend.basic import Linear
 # 98

 # 99
-
+from mindspore.nn.layer import AvgPool2dExt as AvgPool2d
 # 100
-from mindspore.
+from mindspore.nn.layer import SoftShrink as Softshrink
+# 159
+
+# 220
+from mindspore.nn.layer import HShrink as Hardshrink
+# 221
+from mindspore.nn.layer import HSigmoid as Hardsigmoid
+# 222
+from mindspore.nn.layer import HSwish as Hardswish
+# 238
+from mindspore.nn.loss import L1LossExt as L1Loss
+
+# 257
+
+# 258
+from mindspore.ops.function.nn_func import mse_loss_ext
+
+
+# 674
+from mindspore.mint.nn.layer.normalization import BatchNorm1d
+
+# 675
+from mindspore.mint.nn.layer.normalization import BatchNorm2d
+
+# 676
+from mindspore.mint.nn.layer.normalization import BatchNorm3d
+
+from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool1d
+
+from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool2d


 class BCEWithLogitsLoss(Cell):

@@ -294,9 +326,10 @@ class BCEWithLogitsLoss(Cell):
     >>> print(output)
     0.3463612
     """
+
     def __init__(self, weight=None, reduction='mean', pos_weight=None):
         super(BCEWithLogitsLoss, self).__init__()
-        self.bce_with_logits =
+        self.bce_with_logits = ops.auto_generate.BCEWithLogitsLoss(reduction)
         self.weight = weight
         self.pos_weight = pos_weight

@@ -304,14 +337,199 @@ class BCEWithLogitsLoss(Cell):
         out = self.bce_with_logits(input, target, self.weight, self.pos_weight)
         return out

+
+class SELU(Cell):
+    r"""
+    Activation function SELU (Scaled exponential Linear Unit).
+
+    Refer to :func:`mindspore.mint.nn.functional.selu` for more details.
+
+    SELU Activation Function Graph:
+
+    .. image:: ../images/SeLU.png
+        :align: center
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> selu = mint.nn.SELU()
+        >>> output = selu(input)
+        >>> print(output)
+        [[-1.1113307  4.202804  -1.7575096]
+         [ 2.101402  -1.7462534  9.456309 ]]
+    """
+
+    def __init__(self):
+        """Initialize SELU"""
+        super(SELU, self).__init__()
+
+    def construct(self, input):
+        return F.selu(input)
+
+
+class GELU(Cell):
+    r"""
+    Activation function GELU (Gaussian Error Linear Unit).
+
+    Refer to :func:`mindspore.mint.nn.functional.gelu` for more details.
+
+    GELU Activation Function Graph:
+
+    .. image:: ../images/GELU.png
+        :align: center
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> gelu = mint.nn.GELU()
+        >>> output = gelu(input)
+        >>> print(output)
+        [[-1.5880802e-01 3.9999299e+00 -3.1077917e-21]
+         [ 1.9545976e+00 -2.2918017e-07 9.0000000e+00]]
+        >>> gelu = mint.nn.GELU(approximate=False)
+        >>> # CPU not support "approximate=False", using "approximate=True" instead
+        >>> output = gelu(input)
+        >>> print(output)
+        [[-1.5865526e-01 3.9998732e+00 -0.0000000e+00]
+         [ 1.9544997e+00 -1.4901161e-06 9.0000000e+00]]
+    """
+
+    def __init__(self):
+        """Initialize GELU"""
+        super(GELU, self).__init__()
+
+    def construct(self, input):
+        return F.gelu(input)
+
+
+class Mish(Cell):
+    r"""
+    Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function)
+    of input tensors element-wise.
+
+    Refer to :func:`mindspore.mint.nn.functional.mish` for more details.
+
+    Mish Activation Function Graph:
+
+    .. image:: ../images/Mish.png
+        :align: center
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[-1.1, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> mish = mint.nn.Mish()
+        >>> output = mish(x)
+        >>> print(output)
+        [[-3.0764845e-01 3.9974124e+00 -2.6832507e-03]
+         [ 1.9439589e+00 -3.3576239e-02 8.9999990e+00]]
+    """
+
+    def __init__(self):
+        """Initialize Mish."""
+        super(Mish, self).__init__()
+
+    def construct(self, input):
+        return F.mish(input)
+
+
+class MSELoss(Cell):
+    r"""
+    Calculates the mean squared error between the predicted value and the label value.
+
+    For simplicity, let :math:`x` and :math:`y` be 1-dimensional Tensor with length :math:`N`,
+    the unreduced loss (i.e. with argument reduction set to 'none') of :math:`x` and :math:`y` is given as:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with} \quad l_n = (x_n - y_n)^2.
+
+    where :math:`N` is the batch size. If `reduction` is not ``'none'``, then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    Args:
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.
+
+    Inputs:
+        - **logits** (Tensor) - The predicted value of the input. Tensor of any dimension.
+          The data type needs to be consistent with the `labels`. It should also be broadcastable with the `labels`.
+        - **labels** (Tensor) - The input label. Tensor of any dimension.
+          The data type needs to be consistent with the `logits`. It should also be broadcastable with the `logits`.
+
+    Outputs:
+        - Tensor. If `reduction` is ``'mean'`` or ``'sum'``, the shape of output is `Tensor Scalar`.
+        - If reduction is ``'none'``, the shape of output is the broadcasted shape of `logits` and `labels` .
+
+    Raises:
+        ValueError: If `reduction` is not one of ``'mean'``, ``'sum'`` or ``'none'``.
+        ValueError: If `logits` and `labels` are not broadcastable.
+        TypeError: If `logits` and `labels` are in different data type.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> # Case 1: logits.shape = labels.shape = (3,)
+        >>> loss = nn.MSELoss()
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([1, 1, 1]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        1.6666667
+        >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
+        >>> loss = nn.MSELoss(reduction='none')
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        [[0. 1. 4.]
+         [0. 0. 1.]]
+    """
+
+    def __init__(self, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self.mse_loss = mse_loss_ext
+        self.reduction = reduction
+
+    def construct(self, input, target):
+        out = self.mse_loss(input, target, self.reduction)
+        return out
+
+
 __all__ = [
-    'MaxPool2d',
     # 1
     'BCEWithLogitsLoss',
     # 2

     # 3
-
+    'Identity',
     # 4

     # 5

@@ -321,12 +539,13 @@ __all__ = [
     # 7
     'Unfold',
     # 8
-
+    'Softmax',
     # 9
     'Upsample',
     # 10

     # 11
+    'ReLU',

     # 12

@@ -337,11 +556,11 @@ __all__ = [
     # 15

     # 16
-
+    'LogSoftmax',
     # 17

     # 18
-
+    'PReLU',
     # 19

     # 20

@@ -385,6 +604,7 @@ __all__ = [
     # 39

     # 40
+    'GroupNorm',

     # 41

@@ -397,6 +617,7 @@ __all__ = [
     # 45

     # 46
+    'SiLU',

     # 47

@@ -497,16 +718,40 @@ __all__ = [
     # 95

     # 96
+    'AdaptiveAvgPool1d',

     # 97
+    'AdaptiveAvgPool2d',

     # 98

     # 99
-
+    'AvgPool2d',
     # 100
+    'SELU',
+    # 159
+    'GELU',
+    # 220
+    'Hardshrink',
+    # 221
+    'Hardsigmoid',
+    # 222
+    'Hardswish',
+    # 238
+    'L1Loss',
+    # 267
+    'Mish',
+    # 258
+    'MSELoss',
+    # 259
+
+    # 556
+    'LogSigmoid',
+
+    # 674
+    'BatchNorm1d',
+    # 675
+    'BatchNorm2d',
+    # 676
+    'BatchNorm3d',
 ]
-
-__all__.extend(basic.__all__)
-__all__.extend(embedding.__all__)
-__all__.extend(normalization.__all__)