mindspore-2.7.0-cp310-cp310-win_amd64.whl → mindspore-2.7.1-cp310-cp310-win_amd64.whl
This diff shows the changes between two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/parallel/_transformer/loss.py
@@ -1,251 +0,0 @@
-# Copyright 2023 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""
-Parallel Loss for the Parallel Training.
-These are experimental APIs that are subject to change or deletion.
-"""
-from __future__ import absolute_import
-
-from mindspore.parallel import set_algo_parameters
-from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as mstype
-from mindspore.ops import operations as P
-from mindspore.ops import functional as F
-from mindspore.nn import Cell
-from mindspore.nn.loss.loss import _check_is_tensor
-from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
-from mindspore.context import ParallelMode
-from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
-from mindspore.log import _LogActionOnce
-from mindspore import log as logger
-from mindspore.parallel._transformer.layers import _check_input_dtype
-from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
-
-__all__ = ["CrossEntropyLoss"]
-
-
-class _Softmax(Cell):
-    """
-    Calculate the softmax results with given logits.
-
-    Note:
-        The bprop of the cell is rewritten, just returns the accepted dout as returns. This cell should be used
-        together with _NLLLoss, to optimize the bprop of the cross entropy loss.
-
-    Args:
-        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
-            an instance of `OpParallelConfig` with default args.
-
-    Inputs:
-        - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
-          the backbone.
-
-
-    Outputs:
-        Tensor. The corresponding softmax results.
-    """
-    def __init__(self, parallel_config=default_dpmp_config):
-        super(_Softmax, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
-                            ", but got the type: {}.".format(type(parallel_config)))
-        dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
-        # on/off value for onehot, for smooth labeling, modify the off_value
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
-
-        self.sum = P.ReduceSum().shard(((dp, mp),))
-        self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
-            ((dp, mp),))
-        self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
-        self.exp = P.Exp().shard(((dp, mp),))
-        self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
-        self.onehot = P.OneHot().shard(((dp, mp), (), ()))
-
-    def construct(self, logits, label):
-        # LogSoftmax for logits over last dimension
-        logits = F.cast(logits, mstype.float32)
-        _, logit_max = self.max(logits)
-        logit_sub = self.sub(logits, logit_max)
-        logit_exp = self.exp(logit_sub)
-        exp_sum = self.sum(logit_exp, -1)
-        exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
-        softmax_result = self.div(logit_exp, exp_sum)
-
-        one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
-        return softmax_result, one_hot_label
-
-    def bprop(self, logits, label, out, dout):
-        """just return the loss of the dout. Note this should be used together with _NLLLoss"""
-        d_logits = F.cast(dout[0], F.dtype(logits))
-        return d_logits, F.zeros_like(label)
-
-
-class _NLLLoss(Cell):
-    """
-    Calculate the NLLLoss results with given softmax results and the label.
-
-    Note:
-        The bprop of the cell is rewritten. This cell should be used
-        together with _Softmax, to optimize the bprop of the cross entropy loss.
-
-    Args:
-        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
-            an instance of `OpParallelConfig` with default args.
-
-    Inputs:
-        - **loss** (Tensor) - Tensor of shape (N, C). Data type is float32.
-
-    Outputs:
-        Tensor. The corresponding loss results.
-    """
-    def __init__(self, parallel_config=default_dpmp_config):
-        super(_NLLLoss, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
-                            ", but got the type: {}.".format(type(parallel_config)))
-        dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
-        self.repeat_loss = 1
-        self.eps_const = Tensor(1e-24, mstype.float32)
-        # In auto parallel, there will be a virtual div in the back propagation begins. As we use custom bprop function
-        # we need to eliminate this virtual div by adding a factor "mp".
-        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
-            self.repeat_loss = mp
-        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
-            self.sum = P.ReduceSum()
-            self.mul = P.Mul()
-            self.neg = P.Neg()
-            self.log = P.Log()
-            self.add = P.Add().shard(((dp, mp), ()))
-        else:
-            self.sum = P.ReduceSum().shard(((dp, mp),))
-            self.mul = P.Mul().shard(((dp, mp), (dp, mp)))
-            self.neg = P.Neg().shard(((dp, mp),))
-            self.log = P.Log().shard(((dp, mp),))
-            self.add = P.Add().shard(((dp, mp), ()))
-
-    def construct(self, softmax_result, one_hot_label):
-        """The forward of _NLLLoss"""
-        log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
-        loss = self.mul(log_softmax_result, one_hot_label)
-        loss_unsum = self.neg(loss)
-        loss_reduce = self.sum(loss_unsum, -1)
-        return loss_reduce
-
-    def bprop(self, softmax_result, one_hot_label, out, dout):
-        """A simplified function. Note this should be used together with _Softmax"""
-        logits = softmax_result - one_hot_label
-        logits = logits * P.ExpandDims()(dout, -1) * self.repeat_loss
-
-        return logits, F.zeros_like(one_hot_label)
-
-
-class CrossEntropyLoss(Cell):
-    """
-    Calculate the cross entropy loss.
-
-    Args:
-        parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
-            an instance of `OpParallelConfig` with default args.
-
-    Inputs:
-        - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
-          the backbone.
-
-        - **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
-
-        - **input_mask** (Tensor) - Tensor of shape (N, ). input_mask indicates whether there are padded inputs and for
-          padded inputs it will not be counted into loss.
-
-    Outputs:
-        Tensor. The corresponding cross entropy loss.
-
-    Examples:
-        >>> import numpy as np
-        >>> from mindspore import dtype as mstype
-        >>> from mindspore.nn.transformer import CrossEntropyLoss
-        >>> from mindspore import Tensor
-        >>> loss = CrossEntropyLoss()
-        >>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mstype.float32)
-        >>> labels_np = np.array([1]).astype(np.int32)
-        >>> input_mask = Tensor(np.ones(1).astype(np.float32))
-        >>> labels = Tensor(labels_np)
-        >>> output = loss(logits, labels, input_mask)
-        >>> print(output.shape)
-        (1,)
-    """
-    @_LogActionOnce(logger=logger, key='CrossEntropyLoss',
-                    no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
-    def __init__(self, parallel_config=default_dpmp_config):
-        super(CrossEntropyLoss, self).__init__()
-        if not isinstance(parallel_config, OpParallelConfig):
-            raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
-                            ", but got the type: {}.".format(type(parallel_config)))
-        dp = parallel_config.data_parallel
-        mp = parallel_config.model_parallel
-        self.enable_force_redistribute = False
-        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
-            self.enable_force_redistribute = True
-            self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
-            self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
-            self._check_and_modify_sharding_context(dp)
-        self.sum2 = P.ReduceSum().shard(((1,),))
-        self.mul2 = P.Mul().shard(((1,), (1,)))
-        self.add2 = P.Add()
-        self.div2 = P.RealDiv()
-        self.relu = P.ReLU().shard(((1,),))
-
-        self._softmax = _Softmax(parallel_config)
-        self._nllloss = _NLLLoss(parallel_config)
-
-    @staticmethod
-    def _check_and_modify_sharding_context(dp):
-        device_num = _get_device_num()
-        stages = _get_pipeline_stages()
-        if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
-            set_algo_parameters(fully_use_devices=False)
-
-    def construct(self, logits, label, input_mask):
-        self._check_input(logits, label, input_mask)
-
-        # The add is used for forcing the redistribution before stepping in sub graphs, when semi/auto parallel enabled.
-        if self.enable_force_redistribute:
-            logits = self.add(logits, 0)
-            label = self.add_label(label, 0)
-        softmax, one_hot_label = self._softmax(logits, label)
-        loss_reduce = self._nllloss(softmax, one_hot_label)
-
-        # Using input_mask to mask the loss
-        input_mask = P.Reshape()(input_mask, (-1,))
-        numerator = self.sum2(self.mul2(loss_reduce, input_mask))
-
-        denominator = self.add2(
-            self.sum2(input_mask),
-            P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
-        loss = self.div2(numerator, denominator)
-
-        return loss
-
-    def _check_input(self, logits, label, input_mask):
-        r"""Check the input tensor shape and type"""
-        _check_is_tensor('logits', logits, self.cls_name)
-        _check_is_tensor('label', label, self.cls_name)
-        _check_is_tensor('input_mask', input_mask, self.cls_name)
-        _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
-        _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
-        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
-        return True