mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/transformer.py CHANGED
@@ -16,6 +16,7 @@
  Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
  TransformerEncoder, TransformerDecoder, Transformer.
  """
+ import copy
  import math
  from typing import Union, Optional
  import mindspore
@@ -31,7 +32,6 @@ from .basic import Dense, Dropout
  from .activation import ReLU, GELU
  from .normalization import LayerNorm
  from .container import CellList
-
  __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
  'TransformerEncoder', 'TransformerDecoder', 'Transformer']

@@ -588,7 +588,7 @@ class TransformerEncoder(Cell):
  encoder_layer.dropout_num, encoder_layer.activation1,
  encoder_layer.layernorm_eps, encoder_layer.batch_first,
  encoder_layer.norm_first, dtype=encoder_layer.dtype)
- self.layers = CellList([layers for _ in range(num_layers)])
+ self.layers = CellList([copy.deepcopy(layers) for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm

@@ -663,7 +663,7 @@ class TransformerDecoder(Cell):
  decoder_layer.dropout_num, decoder_layer.activation1,
  decoder_layer.layernorm_eps, decoder_layer.batch_first,
  decoder_layer.norm_first, dtype=decoder_layer.dtype)
- self.layers = CellList([layers for _ in range(num_layers)])
+ self.layers = CellList([copy.deepcopy(layers) for _ in range(num_layers)])
  self.num_layers = num_layers
  self.norm = norm
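
The two `copy.deepcopy` changes above fix layer stacking in TransformerEncoder and TransformerDecoder: repeating the same object only creates num_layers references to a single layer whose weights are shared, while deepcopy creates independent layers. A minimal plain-Python sketch of the difference (Layer here is a stand-in class, not MindSpore code):

    import copy

    class Layer:
        def __init__(self):
            self.weight = 0.0

    layer = Layer()
    shared = [layer for _ in range(3)]                      # old behaviour: three references to one object
    independent = [copy.deepcopy(layer) for _ in range(3)]  # new behaviour: three independent copies

    shared[0].weight = 1.0
    print([l.weight for l in shared])       # [1.0, 1.0, 1.0] -- every "layer" changed together
    independent[0].weight = 1.0
    print([l.weight for l in independent])  # [1.0, 0.0, 0.0] -- only the first copy changed
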
mindspore/nn/loss/__init__.py CHANGED
@@ -25,7 +25,7 @@ from mindspore.nn.loss.loss import LossBase, L1Loss, CTCLoss, MSELoss, SmoothL1L
  SampledSoftmaxLoss, TripletMarginWithDistanceLoss,\
  PoissonNLLLoss, MultiLabelSoftMarginLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
  RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss, \
- HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss
+ HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss, L1LossExt


  __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
@@ -33,4 +33,4 @@ __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarg
  'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'TripletMarginWithDistanceLoss', 'PoissonNLLLoss',
  'MultiLabelSoftMarginLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss',
  'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss',
- 'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss']
+ 'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss', 'L1LossExt']
mindspore/nn/loss/loss.py CHANGED
@@ -33,6 +33,7 @@ from mindspore.nn.cell import Cell
  from mindspore.nn.layer.activation import get_activation
  from mindspore import _checkparam as validator
  from mindspore import context
+ from mindspore.ops.auto_generate import l1_loss_ext_op


  class LossBase(Cell):
@@ -247,6 +248,80 @@ class L1Loss(LossBase):
  return F.l1_loss(logits, labels, self.reduction)


+ class L1LossExt(LossBase):
+ r"""
+ L1Loss is used to calculate the mean absolute error between the predicted value and the target value.
+
+ Assuming that the :math:`x` and :math:`y` are 1-D Tensor, length :math:`N`, then calculate the loss of :math:`x` and
+ :math:`y` without dimensionality reduction (the reduction parameter is set to ``'none'`` ). The formula is as
+ follows:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with } l_n = \left| x_n - y_n \right|,
+
+ where :math:`N` is the batch size. If `reduction` is not ``'none'`` , then:
+
+ .. math::
+ \ell(x, y) =
+ \begin{cases}
+ \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+ \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+ \end{cases}
+
+ Args:
+ reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+ ``'sum'`` . Default: ``'mean'`` .
+
+ - ``'none'``: no reduction will be applied.
+ - ``'mean'``: compute and return the mean of elements in the output.
+ - ``'sum'``: the output elements will be summed.
+
+ Inputs:
+ - **logits** (Tensor) - Predicted value, Tensor of any dimension.
+ - **labels** (Tensor) - Target value, same shape as the `logits` in common cases.
+ However, it supports the shape of `logits` is different from the shape of `labels`
+ and they should be broadcasted to each other.
+
+ Outputs:
+ Tensor, data type is float.
+
+ Raises:
+ ValueError: If `reduction` is not one of ``'none'`` , ``'mean'`` or ``'sum'`` .
+ ValueError: If `logits` and `labels` have different shapes and cannot be broadcasted to each other.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore
+ >>> from mindspore import Tensor, nn
+ >>> import numpy as np
+ >>> # Case 1: logits.shape = labels.shape = (3,)
+ >>> loss = nn.L1LossExt()
+ >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+ >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
+ >>> output = loss(logits, labels)
+ >>> print(output)
+ 0.33333334
+ >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
+ >>> loss = nn.L1LossExt(reduction='none')
+ >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+ >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
+ >>> output = loss(logits, labels)
+ >>> print(output)
+ [[0. 1. 2.]
+ [0. 0. 1.]]
+ """
+
+ def __init__(self, reduction='mean'):
+ """Initialize L1LossExt."""
+ super(L1LossExt, self).__init__(reduction)
+ self.reduction = reduction
+
+ def construct(self, logits, labels):
+ return l1_loss_ext_op(logits, labels, self.reduction)
+
+
  class MSELoss(LossBase):
  r"""
  Calculates the mean squared error between the predicted value and the label value.
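
A quick plain-Python check of the Case 1 numbers in the L1LossExt docstring above (illustrative arithmetic only, not a MindSpore call): with reduction='mean' the loss is the mean absolute difference of the two vectors.

    logits = [1.0, 2.0, 3.0]
    labels = [1.0, 2.0, 2.0]
    mae = sum(abs(x - y) for x, y in zip(logits, labels)) / len(logits)
    print(mae)  # 0.3333333333333333, matching the documented float32 output 0.33333334
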
@@ -287,6 +362,7 @@ class MSELoss(LossBase):
  Raises:
  ValueError: If `reduction` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
  ValueError: If `logits` and `labels` have different shapes and cannot be broadcasted.
+ TypeError: if `logits` and `labels` have different data types.

  Supported Platforms:
  ``Ascend`` ``GPU`` ``CPU``
@@ -1580,7 +1656,7 @@ class BCELoss(LossBase):
  The formula is as follow:

  .. math::
- L = \{l_1,\dots,l_N\}^\top, \quad
+ L = \{l_1,\dots,l_n,\dots,l_N\}^\top, \quad
  l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]

  where N is the batch size. Then,
@@ -1850,14 +1926,16 @@ class BCEWithLogitsLoss(LossBase):

  weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
  If not None, it can be broadcast to a tensor with shape of `input`,
- data type must be float16 or float32. Default: ``None`` .
+ data type must be float16, float32 or bfloat16(only Atlas A2 series products are supported).
+ Default: ``None`` .
  pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
  number of classes. If not None, it must be broadcast to a tensor with shape of `input`, data type
- must be float16 or float32. Default: ``None`` .
+ must be float16, float32 or bfloat16(only Atlas A2 series products are supported). Default: ``None`` .

  Inputs:
  - **input** (Tensor) - Input `input` with shape :math:`(N, *)` where :math:`*` means, any number
- of additional dimensions. The data type must be float16 or float32.
+ of additional dimensions. The data type must be float16, float32 or bfloat16(only Atlas A2 series products
+ are supported).
  - **target** (Tensor) - Ground truth label with shape :math:`(N, *)` where :math:`*` means, any number
  of additional dimensions. The same shape and data type as `input`.

@@ -1867,9 +1945,9 @@ class BCEWithLogitsLoss(LossBase):

  Raises:
  TypeError: If input `input` or `target` is not Tensor.
- TypeError: If data type of `input` or `target` is neither float16 nor float32.
+ TypeError: If data type of `input` or `target` is not float16, float32 or bfloat16.
  TypeError: If `weight` or `pos_weight` is a parameter.
- TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
+ TypeError: If data type of `weight` or `pos_weight` is not float16 , float32 or bfloat16.
  TypeError: If data type of `reduction` is not string.
  ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `input`.
  ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
mindspore/nn/optim/__init__.py CHANGED
@@ -38,7 +38,8 @@ from mindspore.nn.optim.adafactor import AdaFactor
  from mindspore.nn.optim.adasum import AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
  from mindspore.nn.optim.adamax import AdaMax
  from mindspore.nn.optim.adadelta import Adadelta
+ from mindspore.nn.optim.tft_wrapper import OptTFTWrapper

  __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
  'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor',
- 'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell', 'AdaMax', 'Adadelta']
+ 'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell', 'AdaMax', 'Adadelta', 'OptTFTWrapper']
mindspore/nn/optim/adadelta.py CHANGED
@@ -55,7 +55,7 @@ class Adadelta(Optimizer):
  w_{t} = w_{t-1} - \gamma * update_{t}
  \end{array}

- where :math:`g` represents `grads`, :math:`\gamma` represents `learning_rate`, :math:`p` represents `rho`,
+ where :math:`g` represents `grads`, :math:`\gamma` represents `learning_rate`, :math:`\rho` represents `rho`,
  :math:`\epsilon` represents `epsilon`, :math:`w` represents `params`,
  :math:`accum` represents accumulation, :math:`accum\_update` represents accumulation update,
  :math:`t` represents current step.
mindspore/nn/optim/adam.py CHANGED
@@ -906,7 +906,7 @@ class AdamWeightDecay(Optimizer):
  There is usually no connection between a optimizer and mixed precision. But when `FixedLossScaleManager` is used
  and `drop_overflow_update` in `FixedLossScaleManager` is set to False, optimizer needs to set the 'loss_scale'.
  As this optimizer has no argument of `loss_scale`, so `loss_scale` needs to be processed by other means, refer
- document `LossScale <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ to
+ document `LossScale <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ to
  process `loss_scale` correctly.

  If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
mindspore/nn/optim/lamb.py CHANGED
@@ -132,7 +132,7 @@ class Lamb(Optimizer):
  There is usually no connection between a optimizer and mixed precision. But when `FixedLossScaleManager` is used
  and `drop_overflow_update` in `FixedLossScaleManager` is set to False, optimizer needs to set the 'loss_scale'.
  As this optimizer has no argument of `loss_scale`, so `loss_scale` needs to be processed by other means. Refer
- document `LossScale <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ to
+ document `LossScale <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ to
  process `loss_scale` correctly.

  If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
mindspore/nn/optim/tft_wrapper.py ADDED
@@ -0,0 +1,127 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+ """OptTFTWrapper"""
+ from __future__ import absolute_import
+
+ import os
+ from mindspore.common.tensor import Tensor
+ from mindspore.nn.optim.optimizer import Optimizer
+ from mindspore.ops.operations.manually_defined._inner import TensorReport
+ from mindspore import ops, context
+
+ class OptTFTWrapper(Optimizer):
+ r"""
+ Implements TFT optimizer wrapper, this wrapper is used to report status to MindIO TFT before optimizer updating.
+
+ Note:
+ This optimizer is depend on MindIO TFT feature. Currently only support ascend graph mode and
+ sink_size must be less than 1.
+
+ Args:
+ opt (Optimizer): Must be sub-class of Optimizer.
+
+ Inputs:
+ - **gradients** (tuple[Tensor]) - The gradients of opt's `params`, the shape is the same as opt's `params`.
+
+ Outputs:
+ Tensor, result of executing optimizer 'opt'.
+
+ Raises:
+ TypeError: If the parameter opt is not an subclass of Optimizer.
+ ValueError: If the platform is not Ascend graph mode, or customer doesn't switch on TFT feature.
+
+ Supported Platforms:
+ ``Ascend``
+
+ Examples:
+ >>> import mindspore as ms
+ >>> from mindspore import nn
+ >>>
+ >>> # Define the network structure of LeNet5. Refer to
+ >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+ >>> net = LeNet5()
+ >>> #1) All parameters use the same learning rate and weight decay
+ >>> optim = nn.SGD(params=net.trainable_params())
+ >>> optim_wrapper = nn.OptTFTWrapper(optim)
+ >>>
+ >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+ >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
+ """
+
+ def __init__(self, opt, **kwargs):
+ super(OptTFTWrapper, self).__init__(opt.learning_rate, opt._parameters) # pylint: disable=W0212
+ if not isinstance(opt, Optimizer):
+ raise TypeError(f"For 'OptTFTWrapper', the argument 'opt' must be Optimizer type, " f"but got {type(opt)}.")
+ tft_env = os.getenv("MS_ENABLE_TFT", "")
+ if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
+ raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
+ mode = context.get_context("mode")
+ device_target = context.get_context("device_target")
+ if device_target != "Ascend" or mode != context.GRAPH_MODE:
+ raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
+ self.opt = opt
+ self.report = TensorReport()
+ self.depend = ops.Depend()
+ self.g_one = Tensor([0.1])
+ # enable consistent check by default, only disable when enable_consistent_check is False
+ self.use_allreduce = kwargs.get("enable_consistent_check", True)
+
+ if self.use_allreduce:
+ self.allreduce_sum = ops.AllReduce()
+ self.allreduce_sum.add_prim_attr("tft_report_before", True)
+
+ self.param_rank = opt.param_rank
+ self.optim_filter = opt.optim_filter
+ self.loss_scale = opt.loss_scale
+ self.dynamic_weight_decay = opt.dynamic_weight_decay
+ self.grad_centralization = opt.grad_centralization
+
+ self.dynamic_lr = opt.dynamic_lr
+ self.global_step = opt.global_step
+ self.is_group = opt.is_group
+ self.is_group_lr = opt.is_group_lr
+ self.is_group_params_ordered = opt.is_group_params_ordered
+ self.use_parallel = opt.use_parallel
+ if self.is_group:
+ self.group_params = opt.group_params
+ self.group_lr = opt.group_lr
+ self.group_weight_decay = opt.group_weight_decay
+ self.group_grad_centralization = opt.group_grad_centralization
+ self.grad_centralization_flags = opt.grad_centralization_flags
+
+ self.skip_auto_parallel_compile = opt.skip_auto_parallel_compile
+
+ self.learning_rate = opt.learning_rate
+ self.parameters = opt.parameters
+ self.decay_flags = opt.decay_flags
+ self.dynamic_decay_flags = opt.dynamic_decay_flags
+ self.weight_decay = opt.weight_decay
+ self.exec_weight_decay = opt.exec_weight_decay
+ self.ps_parameters = opt.ps_parameters
+ self.cache_enable = opt.cache_enable
+ self.reciprocal_scale = opt.reciprocal_scale
+ self.need_scale = opt.need_scale
+ self.global_step_increase_tensor = opt.global_step_increase_tensor
+ self.param_length = opt.param_length
+ self.enable_tuple_broaden = opt.enable_tuple_broaden
+
+ def construct(self, gradients):
+ g_one = self.depend(self.g_one, gradients)
+ if self.use_allreduce is True:
+ g_one_res = self.allreduce_sum(g_one)
+ else:
+ g_one_res = g_one
+ self.report("tft_report", g_one_res)
+ return self.opt(gradients)
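
A hedged usage sketch built from the checks in OptTFTWrapper.__init__ above: the wrapper refuses to construct unless MS_ENABLE_TFT switches on TTP or UCE and the context is Ascend graph mode. LeNet5 is the placeholder network from the docstring example; passing the wrapper (rather than the inner optimizer) to Model is an assumption here.

    import os
    import mindspore as ms
    from mindspore import nn

    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1}"        # switch the MindIO TFT feature on
    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")

    net = LeNet5()                                       # placeholder network, as in the docstring example
    optim = nn.SGD(params=net.trainable_params())
    optim_wrapper = nn.OptTFTWrapper(optim)              # reports to MindIO TFT before each update
    loss = nn.SoftmaxCrossEntropyWithLogits()
    # assumption: the wrapper stands in for the optimizer; the docstring example passes optim instead
    model = ms.train.Model(net, loss_fn=loss, optimizer=optim_wrapper)
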
mindspore/nn/wrap/cell_wrapper.py CHANGED
@@ -23,7 +23,7 @@ from types import FunctionType, MethodType
  from mindspore import log as logger
  from mindspore.parallel._utils import _get_device_num, _get_gradients_mean,\
  _get_parallel_mode, _get_enable_parallel_optimizer, _is_pynative_parallel
- from mindspore.context import ParallelMode, GRAPH_MODE, get_context
+ from mindspore.context import ParallelMode
  from mindspore import _checkparam as validator
  from mindspore import ops, nn
  from mindspore.common import dtype as mstype
@@ -36,6 +36,7 @@ from mindspore.ops import operations as P
  from mindspore.ops.operations.comm_ops import _VirtualDataset
  from mindspore.nn.cell import Cell
  from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+ from mindspore.utils import ExitByRequest

  _get_datatype = C.MultitypeFuncGraph("_get_datatype")

@@ -414,6 +415,11 @@
  group = server_group_name
  self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree, group=group)
  self._get_attr_from_cell(network)
+ self.use_graceful_exit = os.environ.get("MS_ENABLE_GRACEFUL_EXIT") == "1"
+ if self.use_graceful_exit:
+ self.graceful_exit = ExitByRequest()
+ self.exit_param = Parameter(Tensor(False, mstype.bool_), name="graceful_exit") # update by reduce value
+ self.init_param = Parameter(Tensor([0], mstype.int32), name="graceful_init") # update by config file

  def construct(self, *inputs):
  if not self.sense_flag:
@@ -422,6 +428,8 @@
  sens = F.fill(loss.dtype, loss.shape, self.sens)
  grads = self.grad(self.network, self.weights)(*inputs, sens)
  grads = self.grad_reducer(grads)
+ if self.use_graceful_exit:
+ grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param)
  loss = F.depend(loss, self.optimizer(grads))
  if self.return_grad:
  grad_with_param_name = {}
@@ -435,6 +443,8 @@
  loss = self.network(*inputs)
  grads = self.grad_no_sens(self.network, self.weights)(*inputs)
  grads = self.grad_reducer(grads)
+ if self.use_graceful_exit:
+ grads = self.graceful_exit.exit_by_request(grads, self.init_param, self.exit_param)
  loss = F.depend(loss, self.optimizer(grads))
  if self.return_grad:
  grad_with_param_name = {}
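
A short hedged note on the graceful-exit additions above: the feature is gated on an environment variable read when the cell is constructed, so presumably it is enabled like this before TrainOneStepCell is built.

    import os

    # assumption: must be set before TrainOneStepCell is constructed, since
    # use_graceful_exit is read once in __init__
    os.environ["MS_ENABLE_GRACEFUL_EXIT"] = "1"
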
@@ -742,18 +752,7 @@ class _TrainGradAccuStepCell(TrainOneStepCell):
  self.hyper_map = ops.HyperMap()
  self.opt_shard = _get_enable_parallel_optimizer()
  self._get_attr_from_cell(network)
- self.enable_mindio = False
- mode = get_context("mode")
- device_type = get_context("device_target")
- if device_type != "Ascend" or mode != GRAPH_MODE:
- return
- graceful_exit = os.getenv("MS_ENABLE_MINDIO_GRACEFUL_EXIT")
- ttp_lib_path = os.getenv("MS_MINDIO_TTP_LIB_PATH")
- ttp_path_check = ttp_lib_path is not None and os.path.isfile(ttp_lib_path)
- if graceful_exit == "true" and ttp_path_check:
- self.g_one = Tensor([0.1])
- self.allreduce_sum = ops.AllReduce()
- self.enable_mindio = True
+ self.enable_tft = False

  def construct(self, *inputs):
  if not self.sense_flag:
@@ -762,11 +761,6 @@
  sens = ops.fill(ops.DType()(loss), ops.Shape()(loss), self.sens)
  grads = self.grad(self.network, self.weights)(*inputs, sens)
  accu_grads = ops.depend(self.accu_grads, grads)
- if self.enable_mindio:
- g_one = ops.depend(self.g_one, accu_grads)
- g_one_res = self.allreduce_sum(g_one)
- accu_grads = ops.depend(accu_grads, g_one_res)
- grads = ops.depend(grads, g_one_res)
  if self.opt_shard:
  succ = self.optimizer(grads)
  else:
@@ -781,11 +775,6 @@
  loss = self.network(*inputs)
  grads = self.grad_no_sens(self.network, self.weights)(*inputs)
  accu_grads = ops.depend(self.accu_grads, grads)
- if self.enable_mindio:
- g_one = ops.depend(self.g_one, accu_grads)
- g_one_res = self.allreduce_sum(g_one)
- accu_grads = ops.depend(accu_grads, g_one_res)
- grads = ops.depend(grads, g_one_res)
  if self.opt_shard:
  succ = self.optimizer(grads)
  else:
mindspore/nn/wrap/grad_reducer.py CHANGED
@@ -335,14 +335,14 @@ class DistributedGradReducer(Cell):

  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
  Please see the `rank table Startup
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
  for more details.

  For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .

  For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
- Startup <https://www.mindspore.cn/tutorials/experts/en/master/parallel/dynamic_cluster.html>`_ .
+ Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .

  This example should be run with multiple devices.

@@ -509,11 +509,11 @@ class PipelineGradReducer(Cell):

  For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
  Please see the `rank table Startup
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/rank_table.html>`_
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
  for more details.

  For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
- <https://www.mindspore.cn/tutorials/experts/en/master/parallel/mpirun.html>`_ .
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .

  This example should be run with multiple devices.

mindspore/nn/wrap/loss_scale.py CHANGED
@@ -33,6 +33,8 @@ from mindspore.ops.operations.nn_ops import AllFinite
  from mindspore.common import dtype as mstype
  from mindspore.common.api import jit
  from mindspore._c_expression import MSContext
+ from mindspore.run_check._check_version import AscendEnvChecker
+ from mindspore import log as logger

  _grad_scale = C.MultitypeFuncGraph("grad_scale")
  reciprocal = P.Reciprocal()
@@ -49,6 +51,7 @@ def tensor_grad_scale_row_tensor(scale, grad):
  grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
  grad.dense_shape)

+
  _grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
  grad_overflow = P.FloatStatus()

@@ -355,6 +358,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  >>> train_network.set_sense_scale(scaling_sens)
  >>> output = train_network(inputs, label)
  """
+
  def __init__(self, network, optimizer, scale_sense):
  super(TrainOneStepWithLossScaleCell, self).__init__(network, optimizer, sens=None)
  self.hyper_map = C.HyperMap()
@@ -369,7 +373,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
  self.gpu_target = (context.get_context("device_target") == "GPU")
  self.ascend_910a_target = (MSContext.get_instance().get_ascend_soc_version() == 'ascend910')
- self.ascend_910bc_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910c'])
+ self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910_93'])
  self.loss_scaling_manager = None
  self._ascend_check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE')

@@ -377,12 +381,21 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  runtime_conf = os.environ.get('MS_DEV_RUNTIME_CONF')
  global_jit_config = context.get_jit_config()
  if runtime_conf is not None and ("all_finite:True" in runtime_conf or "all_finite:true" in runtime_conf):
+ logger.debug("Enable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
  self.enable_allfinite = True
  elif runtime_conf is not None and ("all_finite:False" in runtime_conf or "all_finite:false" in runtime_conf):
+ logger.debug("Disable AllFinite through the environment variable MS_DEV_RUNTIME_CONF.")
  self.enable_allfinite = False
  elif global_jit_config:
+ logger.debug("Current global jit config is: {}".format(global_jit_config["jit_level"]))
  self.enable_allfinite = global_jit_config["jit_level"] == "O0" or global_jit_config["jit_level"] == "O1"

+ if self.ascend_910b_target:
+ checker = AscendEnvChecker(None)
+ if not checker.check_custom_version():
+ logger.debug("Disable AllFinite due to version check failure.")
+ self.enable_allfinite = False
+
  if isinstance(scale_sense, Cell):
  self.loss_scaling_manager = scale_sense
  self.scale_sense = Parameter(Tensor(scale_sense.get_loss_scale(), dtype=mstype.float32),
@@ -460,7 +473,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  is cleaned up when the function returns.
  """
  status = Tensor([0] * 8, mstype.int32)
- if self.ascend_910a_target or (self.ascend_910bc_target and \
+ if self.ascend_910a_target or (self.ascend_910b_target and \
  self._ascend_check_overflow_mode == "SATURATION_MODE"):
  status = F.depend(status, pre_cond)
  # clear overflow buffer
@@ -554,7 +567,7 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell):
  """
  if self.gpu_target:
  overflow = self._get_gpu_overflow_status(compute_output)
- elif self.ascend_910bc_target:
+ elif self.ascend_910b_target:
  if self._ascend_check_overflow_mode == "SATURATION_MODE":
  overflow = self._get_ascend_overflow_status_on_saturation_mode(status, compute_output)
  else:
@@ -613,6 +626,7 @@ class _TrainGradAccuWithLossScaleCell(TrainOneStepCell):
  optimizer (Optimizer): Optimizer for updating the weights.
  scale_sense (Cell): Cell to do the loss scale.
  """
+
  def __init__(self, network, optimizer, scale_sense):
  super(_TrainGradAccuWithLossScaleCell, self).__init__(network, optimizer, sens=None)
  self.network = network
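
A hedged sketch of the AllFinite toggle implied by the branches above: the setting is read from MS_DEV_RUNTIME_CONF when TrainOneStepWithLossScaleCell is constructed, so it would be set beforehand.

    import os

    # assumption: the cell only searches for the "all_finite:..." substring, so a
    # minimal value like this is enough for illustration
    os.environ["MS_DEV_RUNTIME_CONF"] = "all_finite:True"   # or "all_finite:False" to disable AllFinite
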
mindspore/numpy/__init__.py CHANGED
@@ -64,7 +64,7 @@ from mindspore.numpy.logic_ops import (not_equal, less_equal, less, greater_equa
  logical_or, logical_xor, in1d, isin, isclose, signbit, sometrue,
  array_equal, array_equiv, setdiff1d)

- from . import fft
+ from mindspore.numpy import fft

  mod = remainder
  fabs = absolute