mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic; consult the package registry's advisory page for more details.

Files changed (290)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
  6. mindspore/_extends/parse/compile_config.py +24 -1
  7. mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
  8. mindspore/_extends/parse/resources.py +1 -1
  9. mindspore/_extends/parse/standard_method.py +8 -1
  10. mindspore/_extends/parse/trope.py +2 -1
  11. mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/boost/base.py +29 -2
  18. mindspore/common/_decorator.py +3 -2
  19. mindspore/common/_grad_function.py +3 -1
  20. mindspore/common/_tensor_cpp_method.py +1 -1
  21. mindspore/common/_tensor_docs.py +275 -64
  22. mindspore/common/_utils.py +0 -44
  23. mindspore/common/api.py +285 -35
  24. mindspore/common/dump.py +7 -108
  25. mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
  26. mindspore/common/hook_handle.py +60 -0
  27. mindspore/common/jit_config.py +5 -1
  28. mindspore/common/jit_trace.py +27 -12
  29. mindspore/common/lazy_inline.py +5 -3
  30. mindspore/common/parameter.py +13 -107
  31. mindspore/common/recompute.py +4 -11
  32. mindspore/common/tensor.py +16 -169
  33. mindspore/communication/_comm_helper.py +11 -1
  34. mindspore/communication/comm_func.py +138 -4
  35. mindspore/communication/management.py +85 -1
  36. mindspore/config/op_info.config +0 -15
  37. mindspore/context.py +5 -85
  38. mindspore/dataset/engine/datasets.py +8 -4
  39. mindspore/dataset/engine/datasets_vision.py +1 -1
  40. mindspore/dataset/engine/validators.py +1 -15
  41. mindspore/dnnl.dll +0 -0
  42. mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
  43. mindspore/graph/custom_pass.py +55 -0
  44. mindspore/include/dataset/execute.h +2 -2
  45. mindspore/jpeg62.dll +0 -0
  46. mindspore/mindrecord/__init__.py +3 -3
  47. mindspore/mindrecord/common/exceptions.py +1 -0
  48. mindspore/mindrecord/config.py +1 -1
  49. mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
  50. mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
  51. mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
  52. mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
  53. mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
  54. mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
  55. mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
  56. mindspore/mindrecord/filereader.py +4 -4
  57. mindspore/mindrecord/filewriter.py +5 -5
  58. mindspore/mindrecord/mindpage.py +2 -2
  59. mindspore/mindrecord/tools/cifar10.py +1 -1
  60. mindspore/mindrecord/tools/cifar100.py +1 -1
  61. mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
  62. mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
  63. mindspore/mindrecord/tools/csv_to_mr.py +1 -1
  64. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  65. mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
  66. mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
  67. mindspore/mindspore_backend_common.dll +0 -0
  68. mindspore/mindspore_backend_manager.dll +0 -0
  69. mindspore/mindspore_cluster.dll +0 -0
  70. mindspore/mindspore_common.dll +0 -0
  71. mindspore/mindspore_core.dll +0 -0
  72. mindspore/mindspore_cpu.dll +0 -0
  73. mindspore/mindspore_dump.dll +0 -0
  74. mindspore/mindspore_frontend.dll +0 -0
  75. mindspore/mindspore_glog.dll +0 -0
  76. mindspore/mindspore_hardware_abstract.dll +0 -0
  77. mindspore/mindspore_memory_pool.dll +0 -0
  78. mindspore/mindspore_ms_backend.dll +0 -0
  79. mindspore/mindspore_ops.dll +0 -0
  80. mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
  81. mindspore/mindspore_profiler.dll +0 -0
  82. mindspore/mindspore_pyboost.dll +0 -0
  83. mindspore/mindspore_pynative.dll +0 -0
  84. mindspore/mindspore_runtime_pipeline.dll +0 -0
  85. mindspore/mindspore_runtime_utils.dll +0 -0
  86. mindspore/mindspore_tools.dll +0 -0
  87. mindspore/mint/__init__.py +15 -10
  88. mindspore/mint/distributed/distributed.py +182 -62
  89. mindspore/mint/nn/__init__.py +2 -16
  90. mindspore/mint/nn/functional.py +4 -110
  91. mindspore/mint/nn/layer/__init__.py +0 -2
  92. mindspore/mint/nn/layer/activation.py +0 -6
  93. mindspore/mint/nn/layer/basic.py +0 -47
  94. mindspore/mint/nn/layer/conv.py +4 -4
  95. mindspore/mint/nn/layer/normalization.py +8 -13
  96. mindspore/mint/nn/layer/pooling.py +0 -4
  97. mindspore/nn/__init__.py +1 -3
  98. mindspore/nn/cell.py +16 -66
  99. mindspore/nn/layer/basic.py +49 -1
  100. mindspore/nn/layer/container.py +16 -0
  101. mindspore/nn/layer/embedding.py +4 -169
  102. mindspore/nn/layer/normalization.py +2 -1
  103. mindspore/nn/layer/thor_layer.py +4 -85
  104. mindspore/nn/optim/ada_grad.py +0 -1
  105. mindspore/nn/optim/adafactor.py +0 -1
  106. mindspore/nn/optim/adam.py +31 -124
  107. mindspore/nn/optim/adamax.py +0 -1
  108. mindspore/nn/optim/asgd.py +0 -1
  109. mindspore/nn/optim/ftrl.py +8 -102
  110. mindspore/nn/optim/lamb.py +0 -1
  111. mindspore/nn/optim/lars.py +0 -3
  112. mindspore/nn/optim/lazyadam.py +25 -218
  113. mindspore/nn/optim/momentum.py +5 -43
  114. mindspore/nn/optim/optimizer.py +6 -55
  115. mindspore/nn/optim/proximal_ada_grad.py +0 -1
  116. mindspore/nn/optim/rmsprop.py +0 -1
  117. mindspore/nn/optim/rprop.py +0 -1
  118. mindspore/nn/optim/sgd.py +0 -1
  119. mindspore/nn/optim/tft_wrapper.py +0 -1
  120. mindspore/nn/optim/thor.py +0 -2
  121. mindspore/nn/probability/bijector/bijector.py +7 -8
  122. mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
  123. mindspore/nn/probability/bijector/power_transform.py +20 -21
  124. mindspore/nn/probability/bijector/scalar_affine.py +5 -5
  125. mindspore/nn/probability/bijector/softplus.py +13 -14
  126. mindspore/nn/wrap/grad_reducer.py +4 -74
  127. mindspore/numpy/array_creations.py +2 -2
  128. mindspore/numpy/fft.py +9 -9
  129. mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
  130. mindspore/onnx/onnx_export.py +137 -0
  131. mindspore/opencv_core4110.dll +0 -0
  132. mindspore/opencv_imgcodecs4110.dll +0 -0
  133. mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
  134. mindspore/ops/__init__.py +2 -0
  135. mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
  136. mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
  137. mindspore/ops/_op_impl/cpu/__init__.py +0 -5
  138. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
  139. mindspore/ops/auto_generate/gen_extend_func.py +2 -7
  140. mindspore/ops/auto_generate/gen_ops_def.py +98 -141
  141. mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
  142. mindspore/ops/communication.py +97 -0
  143. mindspore/ops/composite/__init__.py +5 -2
  144. mindspore/ops/composite/base.py +15 -1
  145. mindspore/ops/composite/multitype_ops/__init__.py +3 -1
  146. mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
  147. mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
  148. mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
  149. mindspore/ops/function/__init__.py +1 -0
  150. mindspore/ops/function/array_func.py +14 -12
  151. mindspore/ops/function/comm_func.py +3883 -0
  152. mindspore/ops/function/debug_func.py +3 -4
  153. mindspore/ops/function/math_func.py +45 -54
  154. mindspore/ops/function/nn_func.py +75 -294
  155. mindspore/ops/function/random_func.py +9 -18
  156. mindspore/ops/functional.py +2 -0
  157. mindspore/ops/functional_overload.py +354 -18
  158. mindspore/ops/operations/__init__.py +2 -5
  159. mindspore/ops/operations/_custom_ops_utils.py +7 -9
  160. mindspore/ops/operations/_inner_ops.py +1 -38
  161. mindspore/ops/operations/_rl_inner_ops.py +0 -933
  162. mindspore/ops/operations/array_ops.py +1 -0
  163. mindspore/ops/operations/comm_ops.py +94 -2
  164. mindspore/ops/operations/custom_ops.py +228 -19
  165. mindspore/ops/operations/debug_ops.py +27 -29
  166. mindspore/ops/operations/manually_defined/ops_def.py +27 -306
  167. mindspore/ops/operations/nn_ops.py +2 -2
  168. mindspore/ops/operations/sparse_ops.py +0 -83
  169. mindspore/ops/primitive.py +1 -17
  170. mindspore/ops/tensor_method.py +72 -3
  171. mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
  172. mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
  173. mindspore/ops_generate/api/functions_cc_generator.py +53 -4
  174. mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
  175. mindspore/ops_generate/common/gen_constants.py +11 -10
  176. mindspore/ops_generate/common/op_proto.py +18 -1
  177. mindspore/ops_generate/common/template.py +102 -245
  178. mindspore/ops_generate/common/template_utils.py +212 -0
  179. mindspore/ops_generate/gen_custom_ops.py +69 -0
  180. mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
  181. mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
  182. mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
  183. mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
  184. mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
  185. mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
  186. mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
  187. mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
  188. mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
  189. mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
  190. mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
  191. mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
  192. mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
  193. mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
  194. mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
  195. mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
  196. mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
  197. mindspore/ops_generate/resources/yaml_loader.py +13 -0
  198. mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
  199. mindspore/parallel/_cell_wrapper.py +1 -1
  200. mindspore/parallel/_parallel_serialization.py +1 -4
  201. mindspore/parallel/_utils.py +29 -6
  202. mindspore/parallel/checkpoint_transform.py +18 -2
  203. mindspore/parallel/cluster/process_entity/_api.py +24 -32
  204. mindspore/parallel/cluster/process_entity/_utils.py +9 -5
  205. mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
  206. mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
  207. mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
  208. mindspore/parallel/strategy.py +336 -0
  209. mindspore/parallel/transform_safetensors.py +117 -16
  210. mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
  211. mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
  212. mindspore/profiler/common/constant.py +5 -0
  213. mindspore/profiler/common/file_manager.py +9 -0
  214. mindspore/profiler/common/msprof_cmd_tool.py +38 -2
  215. mindspore/profiler/common/path_manager.py +56 -24
  216. mindspore/profiler/common/profiler_context.py +2 -12
  217. mindspore/profiler/common/profiler_info.py +3 -3
  218. mindspore/profiler/common/profiler_path_manager.py +13 -0
  219. mindspore/profiler/common/util.py +30 -3
  220. mindspore/profiler/experimental_config.py +2 -1
  221. mindspore/profiler/platform/npu_profiler.py +33 -6
  222. mindspore/run_check/_check_version.py +108 -24
  223. mindspore/runtime/__init__.py +3 -2
  224. mindspore/runtime/executor.py +11 -3
  225. mindspore/runtime/memory.py +112 -0
  226. mindspore/swresample-4.dll +0 -0
  227. mindspore/swscale-6.dll +0 -0
  228. mindspore/tinyxml2.dll +0 -0
  229. mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
  230. mindspore/tools/data_dump.py +130 -0
  231. mindspore/tools/sdc_detect.py +91 -0
  232. mindspore/tools/stress_detect.py +63 -0
  233. mindspore/train/__init__.py +6 -6
  234. mindspore/train/_utils.py +5 -18
  235. mindspore/train/amp.py +6 -4
  236. mindspore/train/callback/_checkpoint.py +0 -9
  237. mindspore/train/callback/_train_fault_tolerance.py +69 -18
  238. mindspore/train/data_sink.py +1 -5
  239. mindspore/train/model.py +38 -211
  240. mindspore/train/serialization.py +126 -387
  241. mindspore/turbojpeg.dll +0 -0
  242. mindspore/utils/__init__.py +6 -3
  243. mindspore/utils/dlpack.py +92 -0
  244. mindspore/utils/dryrun.py +1 -1
  245. mindspore/utils/runtime_execution_order_check.py +10 -0
  246. mindspore/utils/sdc_detect.py +14 -12
  247. mindspore/utils/stress_detect.py +43 -0
  248. mindspore/utils/utils.py +144 -8
  249. mindspore/version.py +1 -1
  250. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
  251. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
  252. mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
  253. mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
  254. mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
  255. mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
  256. mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
  257. mindspore/experimental/llm_boost/register.py +0 -130
  258. mindspore/experimental/llm_boost/utils.py +0 -31
  259. mindspore/include/OWNERS +0 -7
  260. mindspore/mindspore_cpu_res_manager.dll +0 -0
  261. mindspore/mindspore_ops_kernel_common.dll +0 -0
  262. mindspore/mindspore_res_manager.dll +0 -0
  263. mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
  264. mindspore/nn/reinforcement/_batch_read_write.py +0 -142
  265. mindspore/nn/reinforcement/_tensors_queue.py +0 -152
  266. mindspore/nn/reinforcement/tensor_array.py +0 -145
  267. mindspore/opencv_core452.dll +0 -0
  268. mindspore/opencv_imgcodecs452.dll +0 -0
  269. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
  270. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
  271. mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
  272. mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
  273. mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
  274. mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
  275. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
  276. mindspore/ops/operations/_tensor_array.py +0 -359
  277. mindspore/ops/operations/rl_ops.py +0 -288
  278. mindspore/parallel/_offload_context.py +0 -275
  279. mindspore/parallel/_recovery_context.py +0 -115
  280. mindspore/parallel/_transformer/__init__.py +0 -35
  281. mindspore/parallel/_transformer/layers.py +0 -765
  282. mindspore/parallel/_transformer/loss.py +0 -251
  283. mindspore/parallel/_transformer/moe.py +0 -693
  284. mindspore/parallel/_transformer/op_parallel_config.py +0 -222
  285. mindspore/parallel/_transformer/transformer.py +0 -3124
  286. mindspore/parallel/mpi/_mpi_config.py +0 -116
  287. mindspore/train/memory_profiling_pb2.py +0 -298
  288. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
  289. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
  290. {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
@@ -1,251 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """
16
- Parallel Loss for the Parallel Training.
17
- These are experimental APIs that are subject to change or deletion.
18
- """
19
- from __future__ import absolute_import
20
-
21
- from mindspore.parallel import set_algo_parameters
22
- from mindspore.common.tensor import Tensor
23
- import mindspore.common.dtype as mstype
24
- from mindspore.ops import operations as P
25
- from mindspore.ops import functional as F
26
- from mindspore.nn import Cell
27
- from mindspore.nn.loss.loss import _check_is_tensor
28
- from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
29
- from mindspore.context import ParallelMode
30
- from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
31
- from mindspore.log import _LogActionOnce
32
- from mindspore import log as logger
33
- from mindspore.parallel._transformer.layers import _check_input_dtype
34
- from mindspore.parallel._transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
35
-
36
- __all__ = ["CrossEntropyLoss"]
37
-
38
-
39
- class _Softmax(Cell):
40
- """
41
- Calculate the softmax results with given logits.
42
-
43
- Note:
44
- The bprop of the cell is rewritten, just returns the accepted dout as returns. This cell should be used
45
- together with _NLLoss, to optimize the bprop of the cross entroy loss.
46
-
47
- Args:
48
- parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
49
- an instance of `OpParallelConfig` with default args.
50
-
51
- Inputs:
52
- - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
53
- the backbone.
54
-
55
-
56
- Outputs:
57
- Tensor. The corresponding softmax results.
58
- """
59
- def __init__(self, parallel_config=default_dpmp_config):
60
- super(_Softmax, self).__init__()
61
- if not isinstance(parallel_config, OpParallelConfig):
62
- raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
63
- ", but got the type: {}.".format(type(parallel_config)))
64
- dp = parallel_config.data_parallel
65
- mp = parallel_config.model_parallel
66
- # on/off value for onehot, for smooth labeling, modify the off_value
67
- self.on_value = Tensor(1.0, mstype.float32)
68
- self.off_value = Tensor(0.0, mstype.float32)
69
-
70
- self.sum = P.ReduceSum().shard(((dp, mp),))
71
- self.max = P.ArgMaxWithValue(axis=-1, keep_dims=True).shard(
72
- ((dp, mp),))
73
- self.sub = P.Sub().shard(((dp, mp), (dp, 1)))
74
- self.exp = P.Exp().shard(((dp, mp),))
75
- self.div = P.RealDiv().shard(((dp, mp), (dp, 1)))
76
- self.onehot = P.OneHot().shard(((dp, mp), (), ()))
77
-
78
- def construct(self, logits, label):
79
- # LogSoftmax for logits over last dimension
80
- logits = F.cast(logits, mstype.float32)
81
- _, logit_max = self.max(logits)
82
- logit_sub = self.sub(logits, logit_max)
83
- logit_exp = self.exp(logit_sub)
84
- exp_sum = self.sum(logit_exp, -1)
85
- exp_sum = P.Reshape()(exp_sum, (F.shape(exp_sum)[0], 1))
86
- softmax_result = self.div(logit_exp, exp_sum)
87
-
88
- one_hot_label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
89
- return softmax_result, one_hot_label
90
-
91
- def bprop(self, logits, label, out, dout):
92
- """just return the loss of the dout. Note this should be used together with _NLLLoss"""
93
- d_logits = F.cast(dout[0], F.dtype(logits))
94
- return d_logits, F.zeros_like(label)
95
-
96
-
97
- class _NLLLoss(Cell):
98
- """
99
- Calculate the NLLLoss results with given softmax results and the label.
100
-
101
- Note:
102
- The bprop of the cell is rewritten. This cell should be used
103
- together with _Softmax, to optimize the bprop of the cross entroy loss.
104
-
105
- Args:
106
- parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
107
- an instance of `OpParallelConfig` with default args.
108
-
109
- Inputs:
110
- - **loss** (Tensor) - Tensor of shape (N, C). Data type is float32.
111
-
112
- Outputs:
113
- Tensor. The corresponding loss results.
114
- """
115
- def __init__(self, parallel_config=default_dpmp_config):
116
- super(_NLLLoss, self).__init__()
117
- if not isinstance(parallel_config, OpParallelConfig):
118
- raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
119
- ", but got the type: {}.".format(type(parallel_config)))
120
- dp = parallel_config.data_parallel
121
- mp = parallel_config.model_parallel
122
- self.repeat_loss = 1
123
- self.eps_const = Tensor(1e-24, mstype.float32)
124
- # In auto parallel, there will be a virtual div in the back propagation begins. As we use custom bprop function
125
- # we need to eliminate this virtual div by adding a factor "mp".
126
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
127
- self.repeat_loss = mp
128
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation():
129
- self.sum = P.ReduceSum()
130
- self.mul = P.Mul()
131
- self.neg = P.Neg()
132
- self.log = P.Log()
133
- self.add = P.Add().shard(((dp, mp), ()))
134
- else:
135
- self.sum = P.ReduceSum().shard(((dp, mp),))
136
- self.mul = P.Mul().shard(((dp, mp), (dp, mp)))
137
- self.neg = P.Neg().shard(((dp, mp),))
138
- self.log = P.Log().shard(((dp, mp),))
139
- self.add = P.Add().shard(((dp, mp), ()))
140
-
141
- def construct(self, softmax_result, one_hot_label):
142
- """The forward of _NLLLoss"""
143
- log_softmax_result = self.log(self.add(softmax_result, self.eps_const))
144
- loss = self.mul(log_softmax_result, one_hot_label)
145
- loss_unsum = self.neg(loss)
146
- loss_reduce = self.sum(loss_unsum, -1)
147
- return loss_reduce
148
-
149
- def bprop(self, softmax_result, one_hot_label, out, dout):
150
- """A simplified function. Note this should be used together with _Softmax"""
151
- logits = softmax_result - one_hot_label
152
- logits = logits * P.ExpandDims()(dout, -1) * self.repeat_loss
153
-
154
- return logits, F.zeros_like(one_hot_label)
155
-
156
-
157
- class CrossEntropyLoss(Cell):
158
- """
159
- Calculate the cross entropy loss.
160
-
161
- Args:
162
- parallel_config (OpParallelConfig): The parallel configure. Default `default_dpmp_config`,
163
- an instance of `OpParallelConfig` with default args.
164
-
165
- Inputs:
166
- - **logits** (Tensor) - Tensor of shape (N, C). Data type must be float16 or float32. The output logits of
167
- the backbone.
168
-
169
- - **labels** (Tensor) - Tensor of shape (N, ). The ground truth label of the sample.
170
-
171
- - **input_mask** (Tensor) - Tensor of shape (N, ). input_mask indicates whether there are padded inputs and for
172
- padded inputs it will not be counted into loss.
173
-
174
- Outputs:
175
- Tensor. The corresponding cross entropy loss.
176
-
177
- Examples:
178
- >>> import numpy as np
179
- >>> from mindspore import dtype as mstype
180
- >>> from mindspore.nn.transformer import CrossEntropyLoss
181
- >>> from mindspore import Tensor
182
- >>> loss = CrossEntropyLoss()
183
- >>> logits = Tensor(np.array([[3, 5, 6, 9, 12, 33, 42, 12, 32, 72]]), mstype.float32)
184
- >>> labels_np = np.array([1]).astype(np.int32)
185
- >>> input_mask = Tensor(np.ones(1).astype(np.float32))
186
- >>> labels = Tensor(labels_np)
187
- >>> output = loss(logits, labels, input_mask)
188
- >>> print(output.shape)
189
- (1,)
190
- """
191
- @_LogActionOnce(logger=logger, key='CrossEntropyLoss',
192
- no_warning=_get_parallel_mode() in (ParallelMode.STAND_ALONE,))
193
- def __init__(self, parallel_config=default_dpmp_config):
194
- super(CrossEntropyLoss, self).__init__()
195
- if not isinstance(parallel_config, OpParallelConfig):
196
- raise TypeError("For 'CrossEntropyLoss', the class variable 'parallel_config' must be OpParallelConfig"
197
- ", but got the type: {}.".format(type(parallel_config)))
198
- dp = parallel_config.data_parallel
199
- mp = parallel_config.model_parallel
200
- self.enable_force_redistribute = False
201
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL):
202
- self.enable_force_redistribute = True
203
- self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True)
204
- self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True)
205
- self._check_and_modify_sharding_context(dp)
206
- self.sum2 = P.ReduceSum().shard(((1,),))
207
- self.mul2 = P.Mul().shard(((1,), (1,)))
208
- self.add2 = P.Add()
209
- self.div2 = P.RealDiv()
210
- self.relu = P.ReLU().shard(((1,),))
211
-
212
- self._softmax = _Softmax(parallel_config)
213
- self._nllloss = _NLLLoss(parallel_config)
214
-
215
- @staticmethod
216
- def _check_and_modify_sharding_context(dp):
217
- device_num = _get_device_num()
218
- stages = _get_pipeline_stages()
219
- if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and dp * stages != device_num:
220
- set_algo_parameters(fully_use_devices=False)
221
-
222
- def construct(self, logits, label, input_mask):
223
- self._check_input(logits, label, input_mask)
224
-
225
- # The add is used for forcing the redistribution before stepping in sub graphs, when semi/auto parallel enabled.
226
- if self.enable_force_redistribute:
227
- logits = self.add(logits, 0)
228
- label = self.add_label(label, 0)
229
- softmax, one_hot_label = self._softmax(logits, label)
230
- loss_reduce = self._nllloss(softmax, one_hot_label)
231
-
232
- # Using input_mask to mask the loss
233
- input_mask = P.Reshape()(input_mask, (-1,))
234
- numerator = self.sum2(self.mul2(loss_reduce, input_mask))
235
-
236
- denominator = self.add2(
237
- self.sum2(input_mask),
238
- P.Cast()(F.tuple_to_array((1e-5,)), mstype.float32))
239
- loss = self.div2(numerator, denominator)
240
-
241
- return loss
242
-
243
- def _check_input(self, logits, label, input_mask):
244
- r"""Check the input tensor shape and type"""
245
- _check_is_tensor('logits', logits, self.cls_name)
246
- _check_is_tensor('label', label, self.cls_name)
247
- _check_is_tensor('input_mask', input_mask, self.cls_name)
248
- _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
249
- _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
250
- _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
251
- return True