mindspore 2.3.0-cp39-cp39-win_amd64.whl → 2.4.1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/pooling.py CHANGED
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 import mindspore.ops as ops
+from mindspore.ops.function.nn_func import avg_pool2d_ext
 from mindspore._checkparam import _check_3d_int_or_tuple
 from mindspore import _checkparam as validator
 from mindspore.ops.primitive import constexpr, _primexpr
@@ -26,13 +27,14 @@ import mindspore.context as context
 from mindspore.common import dtype as mstype
 from mindspore.ops.operations.nn_ops import AdaptiveMaxPool2D
 from mindspore.ops.operations.nn_ops import AdaptiveMaxPool3D, AdaptiveAvgPool3D
+from mindspore.ops.auto_generate.gen_ops_prim import MaxPoolWithIndices, MaxPoolWithMask
 from mindspore.nn.cell import Cell
 from mindspore._c_expression import MSContext

 __all__ = ['AvgPool3d', 'MaxPool3d', 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'MaxPool1d', 'FractionalMaxPool2d',
            'FractionalMaxPool3d', 'AdaptiveAvgPool1d', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d',
            'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'LPPool1d',
-           'LPPool2d']
+           'LPPool2d', 'AvgPool2dExt', 'MaxPool2dExt']


 class _PoolNd(Cell):
@@ -96,6 +98,9 @@ class LPPool1d(Cell):
     .. math::
         f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+

     Args:
         norm_type (Union[int, float]): Type of normalization, represents :math:`p` in the formula, can not be 0.
@@ -169,6 +174,9 @@ class LPPool2d(Cell):
     .. math::
         f(X) = \sqrt[p]{\sum_{x \in X} x^{p}}

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+

     Args:
         norm_type(Union[int, float]): Type of normalization, represents :math:`p` in the formula, can not be 0.
@@ -374,6 +382,7 @@ class MaxPool3d(_PoolNd):
     Examples:
         >>> import mindspore as ms
         >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor
        >>> import numpy as np
         >>> np_x = np.random.randint(0, 10, [5, 3, 4, 6, 7])
         >>> x = Tensor(np_x, ms.float32)
@@ -592,6 +601,102 @@ class MaxPool2d(_PoolNd):
         return out


+class MaxPool2dExt(Cell):
+    r"""
+    Applies a 2D max pooling over an input Tensor which can be regarded as a composition of 2D planes.
+
+    Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, MaxPool2d outputs
+    regional maximum in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size
+    :math:`(h_{ker}, w_{ker})` and stride :math:`(s_0, s_1)`, the operation is as follows.
+
+    .. math::
+        \text{output}(N_i, C_j, h, w) = \max_{m=0, \ldots, h_{ker}-1} \max_{n=0, \ldots, w_{ker}-1}
+        \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)
+
+    .. warning::
+        Only support on Atlas training series.
+
+    Args:
+        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the max value,
+            is an int number or a single element tuple that represents height and width are both kernel_size,
+            or a tuple of two int numbers that represent height and width respectively.
+            Default: ``1`` .
+        stride (Union[int, tuple[int], None]): The distance of kernel moving, an int number or a single element tuple
+            that represents the height and width of movement are both stride, or a tuple of two int numbers that
+            represent height and width of movement respectively.
+            Default: ``None`` , which indicates the moving step is `kernel_size` .
+        padding (Union(int, tuple[int], list[int])): Specifies the padding value of the pooling operation.
+            Default: ``0`` . `padding` can only be an integer or a tuple/list containing one or two integers. If
+            `padding` is an integer or a tuple/list containing one integer, it will be padded `padding` times in the
+            four directions of the input. If `padding` is a tuple/list containing two integers, it will be padded
+            `padding[0]` times in the up-down direction of the input and `padding[1]` times in the left-right direction
+            of the input.
+        dilation (Union(int, tuple[int])): The spacing between the elements of the kernel in convolution,
+            used to increase the receptive field of the pooling operation. If it is a tuple, it must contain one or two
+            integers. Default: ``1`` .
+        return_indices (bool): If ``True`` , the function will return both the result of max pooling and the indices of
+            the max elements. Default: ``False`` .
+        ceil_mode (bool): If ``True`` , use ceil to compute the output shape instead of floor. Default: ``False`` .
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        If `return_indices` is ``False`` , return a Tensor `output`, else return a tuple (`output`, `argmax`).
+
+        - **output** (Tensor) - Maxpooling result, with shape :math:`(N_{out}, C_{out}, H_{out}, W_{out})`. It has the
+          same data type as `input`.
+        - **argmax** (Tensor) - Index corresponding to the maximum value. Data type is int32.
+
+        .. math::
+            H_{out} = \left\lfloor\frac{H_{in} + 2 * \text{padding[0]} - \text{dilation[0]}
+            \times (\text{kernel_size[0]} - 1) - 1}{\text{stride[0]}} + 1\right\rfloor
+
+        .. math::
+            W_{out} = \left\lfloor\frac{W_{in} + 2 * \text{padding[1]} - \text{dilation[1]}
+            \times (\text{kernel_size[1]} - 1) - 1}{\text{stride[1]}} + 1\right\rfloor
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        ValueError: If length of shape of `input` is not equal to 4.
+        TypeError: If `kernel_size` , `stride` , `padding` or `dilation` is not int or tuple.
+        ValueError: If `kernel_size`, `stride` or `dilation` is less than 1.
+        ValueError: If `dilation` is not all 1.
+        ValueError: If `padding` is less than 0.
+        ValueError: If `padding` is more than half of `kernel_size`.
+        TypeError: If `ceil_mode` is not bool.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> pool = ms.mint.nn.MaxPool2d(kernel_size=3, stride=1)
+        >>> input = ms.Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), ms.float32)
+        >>> output = pool(input)
+        >>> print(output.shape)
+        (1, 2, 2, 2)
+    """
+
+    def __init__(self, kernel_size=1, stride=None, padding=0, dilation=1, return_indices=False,
+                 ceil_mode=False):
+        """Initialize MaxPool2d."""
+        super(MaxPool2dExt, self).__init__()
+        self.return_indices = return_indices
+        strides = stride if (stride is not None) else kernel_size
+        if return_indices:
+            self.max_pool_func_ = MaxPoolWithIndices(kernel_size, strides, padding, dilation, ceil_mode)
+        else:
+            self.max_pool_func_ = MaxPoolWithMask(kernel_size, strides, padding, dilation, ceil_mode)
+
+    def construct(self, input):
+        out, indices = self.max_pool_func_(input)
+        if self.return_indices:
+            return out, indices
+        return out
+
+
 class MaxPool1d(_PoolNd):
     r"""
     Applies a 1D max pooling over an input Tensor which can be regarded as a composition of 1D planes.
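The new `MaxPool2dExt` cell above dispatches to `MaxPoolWithIndices` when `return_indices=True` and to `MaxPoolWithMask` otherwise, and is exported through `__all__` alongside the existing pooling cells. A minimal usage sketch, assuming an Ascend backend (the only platform the docstring lists) and illustrative input values:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), ms.float32)

    # Default path (MaxPoolWithMask): a single Tensor is returned.
    pool = nn.MaxPool2dExt(kernel_size=3, stride=1)
    out = pool(x)                  # shape (1, 2, 2, 2)

    # return_indices=True path (MaxPoolWithIndices): an (output, argmax) tuple.
    pool_idx = nn.MaxPool2dExt(kernel_size=3, stride=1, return_indices=True)
    out, argmax = pool_idx(x)      # argmax is int32 per the docstring above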
@@ -793,6 +898,9 @@ class AvgPool3d(_PoolNd):
         \frac{1}{d_{ker} * h_{ker} * w_{ker}} \sum_{l=0}^{d_{ker}-1} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
         \text{input}(N_i, C_j, s_0 \times d + l, s_1 \times h + m, s_2 \times w + n)

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+
     Args:
         kernel_size (Union[int, tuple[int]], optional): The size of kernel used to take the average value,
             can be an int number or a single element tuple that represents depth, height and width, or a tuple of three
@@ -910,6 +1018,46 @@ class AvgPool3d(_PoolNd):
         return out


+class AvgPool2dExt(Cell):
+    r"""
+    Applies a 2D average pooling over an input Tensor which can be regarded as
+    a composition of 2D input planes.
+
+    For details, please refer to :func:`mindspore.mint.nn.functional.avg_pool2d`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor, nn
+        >>> from mindspore import dtype as mstype
+        >>> x = Tensor(np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4), mstype.float32)
+        >>> m = nn.AvgPool2dExt(x, kernel_size=2, stride=1)
+        >>> output = m(x)
+        >>> print(output)
+        [[[[ 2.5 3.5 4.5]
+           [ 6.5 7.5 8.5]]
+          [[14.5 15.5 16.5]
+           [18.5 19.5 20.5]]
+          [[26.5 27.5 28.5]
+           [30.5 31.5 32.5]]]]
+    """
+    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
+                 count_include_pad=True, divisor_override=None):
+        super(AvgPool2dExt, self).__init__()
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        self.ceil_mode = ceil_mode
+        self.count_include_pad = count_include_pad
+        self.divisor_override = divisor_override
+
+    def construct(self, input):
+        return avg_pool2d_ext(input, self.kernel_size, self.stride, self.padding,
+                              self.ceil_mode, self.count_include_pad, self.divisor_override)
+
+
 class AvgPool2d(_PoolNd):
     r"""
     Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes.
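Judging from the `__init__` signature above, `AvgPool2dExt` is constructed with only the pooling parameters and forwards the input tensor to `avg_pool2d_ext` at call time; a minimal sketch under that reading (Ascend assumed, values illustrative):

    import numpy as np
    from mindspore import Tensor, nn
    from mindspore import dtype as mstype

    x = Tensor(np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4), mstype.float32)

    # Construct with pooling parameters; construct() calls avg_pool2d_ext on the input.
    m = nn.AvgPool2dExt(kernel_size=2, stride=1)
    output = m(x)                  # shape (1, 3, 2, 3)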
@@ -922,6 +1070,9 @@ class AvgPool2d(_PoolNd):
         \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1}
         \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n)

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+
     Args:
         kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value.
             The data type of kernel_size must be int or a single element tuple and the value represents the height
@@ -1015,12 +1166,12 @@ class AvgPool2d(_PoolNd):
                  data_format="NCHW"):
         """Initialize AvgPool2d."""
         super(AvgPool2d, self).__init__(kernel_size, stride, pad_mode, data_format)
-        self.ascend_910bc_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910c'])
+        self.ascend_910b_target = (MSContext.get_instance().get_ascend_soc_version() in ['ascend910b', 'ascend910_93'])
         if pad_mode.upper() == 'PAD' or padding != 0 or ceil_mode or not count_include_pad \
                 or divisor_override is not None:
-            if self.ascend_910bc_target:
-                raise ValueError(f"For '{self.cls_name}, the pad_mod 'PAD' is not support in 910B now, "
-                                 f"it will be supported in the future.")
+            if self.ascend_910b_target:
+                raise ValueError(f"For '{self.cls_name}, the pad_mod 'PAD' is not support in Ascend910B or Ascend910_93"
+                                 f" now, it will be supported in the future.")
             if self.format == "NHWC":
                 raise ValueError(f"For '{self.cls_name}, the 'NHWC' format are not support when 'pad_mode' is 'pad' or "
                                  f"'padding' is not 0 or 'ceil_mode' is not False or 'count_include_pad' is not True"
@@ -1083,6 +1234,9 @@ class AvgPool1d(_PoolNd):
         \text{output}(N_i, C_j, l) = \frac{1}{l_{ker}} \sum_{n=0}^{l_{ker}-1}
         \text{input}(N_i, C_j, s_0 \times l + n)

+    Note:
+        This interface currently does not support Atlas A2 training series products.
+
     Args:
         kernel_size (int): The size of kernel window used to take the average value, Default: ``1`` .
         stride (int): The distance of kernel moving, an int number that represents
@@ -1682,7 +1836,7 @@ class AdaptiveMaxPool3d(Cell):


 class FractionalMaxPool2d(Cell):
     r"""
-    Applies the 2D FractionalMaxPool operatin over input. The output Tensor shape can be determined by either
+    Applies the 2D FractionalMaxPool operation over input. The output Tensor shape can be determined by either
     `output_size` or `output_ratio`, and the step size is determined by `_random_samples`. `output_size` will take
     effect when `output_size` and `output_ratio` are set at the same time.
     And `output_size` and `output_ratio` can not be ``None`` at the same time.
@@ -1783,7 +1937,7 @@ class FractionalMaxPool2d(Cell):


 class FractionalMaxPool3d(Cell):
     r"""
-    Applies the 3D FractionalMaxPool operatin over `input`. The output Tensor shape can be determined by either
+    Applies the 3D FractionalMaxPool operation over `input`. The output Tensor shape can be determined by either
     `output_size` or `output_ratio`, and the step size is determined by `_random_samples`. `output_size` will take
     effect when `output_size` and `output_ratio` are set at the same time.
     And `output_size` and `output_ratio` can not be ``None`` at the same time.
mindspore/nn/layer/transformer.py CHANGED
@@ -16,6 +16,7 @@
 Transformer Cells module, include TransformerEncoderLayer, TransformerDecoderLayer,
 TransformerEncoder, TransformerDecoder, Transformer.
 """
+import copy
 import math
 from typing import Union, Optional
 import mindspore
@@ -31,7 +32,6 @@ from .basic import Dense, Dropout
 from .activation import ReLU, GELU
 from .normalization import LayerNorm
 from .container import CellList
-
 __all__ = ['MultiheadAttention', 'TransformerEncoderLayer', 'TransformerDecoderLayer',
            'TransformerEncoder', 'TransformerDecoder', 'Transformer']

@@ -588,7 +588,7 @@ class TransformerEncoder(Cell):
                                         encoder_layer.dropout_num, encoder_layer.activation1,
                                         encoder_layer.layernorm_eps, encoder_layer.batch_first,
                                         encoder_layer.norm_first, dtype=encoder_layer.dtype)
-        self.layers = CellList([layers for _ in range(num_layers)])
+        self.layers = CellList([copy.deepcopy(layers) for _ in range(num_layers)])
         self.num_layers = num_layers
         self.norm = norm

@@ -663,7 +663,7 @@ class TransformerDecoder(Cell):
                                         decoder_layer.dropout_num, decoder_layer.activation1,
                                         decoder_layer.layernorm_eps, decoder_layer.batch_first,
                                         decoder_layer.norm_first, dtype=decoder_layer.dtype)
-        self.layers = CellList([layers for _ in range(num_layers)])
+        self.layers = CellList([copy.deepcopy(layers) for _ in range(num_layers)])
         self.num_layers = num_layers
         self.norm = norm

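The `copy.deepcopy` change gives every stacked encoder/decoder layer its own parameters; previously the same cell instance was repeated `num_layers` times, so all entries shared one set of weights. A small sketch of the difference, using a plain `nn.Dense` cell purely for illustration (not part of this diff):

    import copy
    from mindspore import nn

    layer = nn.Dense(4, 4)

    shared = nn.CellList([layer for _ in range(3)])                       # one weight, referenced 3 times
    independent = nn.CellList([copy.deepcopy(layer) for _ in range(3)])   # three distinct weights

    print(shared[0].weight is shared[1].weight)            # True
    print(independent[0].weight is independent[1].weight)  # False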
mindspore/nn/loss/__init__.py CHANGED
@@ -25,7 +25,7 @@ from mindspore.nn.loss.loss import LossBase, L1Loss, CTCLoss, MSELoss, SmoothL1L
     SampledSoftmaxLoss, TripletMarginWithDistanceLoss,\
     PoissonNLLLoss, MultiLabelSoftMarginLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss, \
     RMSELoss, MAELoss, HuberLoss, CrossEntropyLoss, NLLLoss, KLDivLoss, MarginRankingLoss, GaussianNLLLoss, \
-    HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss
+    HingeEmbeddingLoss, MultilabelMarginLoss, TripletMarginLoss, L1LossExt


 __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
@@ -33,4 +33,4 @@ __all__ = ['LossBase', 'L1Loss', 'CTCLoss', 'MSELoss', 'SmoothL1Loss', 'SoftMarg
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'TripletMarginWithDistanceLoss', 'PoissonNLLLoss',
            'MultiLabelSoftMarginLoss', 'DiceLoss', 'MultiClassDiceLoss', 'MultilabelMarginLoss',
            'RMSELoss', 'MAELoss', 'HuberLoss', 'CrossEntropyLoss', 'NLLLoss', 'KLDivLoss', 'MarginRankingLoss',
-           'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss']
+           'GaussianNLLLoss', 'HingeEmbeddingLoss', 'TripletMarginLoss', 'L1LossExt']
mindspore/nn/loss/loss.py CHANGED
@@ -33,6 +33,7 @@ from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 from mindspore import _checkparam as validator
 from mindspore import context
+from mindspore.ops.auto_generate import l1_loss_ext_op


 class LossBase(Cell):
@@ -247,6 +248,80 @@ class L1Loss(LossBase):
         return F.l1_loss(logits, labels, self.reduction)


+class L1LossExt(LossBase):
+    r"""
+    L1Loss is used to calculate the mean absolute error between the predicted value and the target value.
+
+    Assuming that the :math:`x` and :math:`y` are 1-D Tensor, length :math:`N`, then calculate the loss of :math:`x` and
+    :math:`y` without dimensionality reduction (the reduction parameter is set to ``'none'`` ). The formula is as
+    follows:
+
+    .. math::
+        \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad \text{with } l_n = \left| x_n - y_n \right|,
+
+    where :math:`N` is the batch size. If `reduction` is not ``'none'`` , then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    Args:
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.
+
+    Inputs:
+        - **logits** (Tensor) - Predicted value, Tensor of any dimension.
+        - **labels** (Tensor) - Target value, same shape as the `logits` in common cases.
+          However, it supports the shape of `logits` is different from the shape of `labels`
+          and they should be broadcasted to each other.
+
+    Outputs:
+        Tensor, data type is float.
+
+    Raises:
+        ValueError: If `reduction` is not one of ``'none'`` , ``'mean'`` or ``'sum'`` .
+        ValueError: If `logits` and `labels` have different shapes and cannot be broadcasted to each other.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> # Case 1: logits.shape = labels.shape = (3,)
+        >>> loss = nn.L1LossExt()
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([1, 2, 2]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        0.33333334
+        >>> # Case 2: logits.shape = (3,), labels.shape = (2, 3)
+        >>> loss = nn.L1LossExt(reduction='none')
+        >>> logits = Tensor(np.array([1, 2, 3]), mindspore.float32)
+        >>> labels = Tensor(np.array([[1, 1, 1], [1, 2, 2]]), mindspore.float32)
+        >>> output = loss(logits, labels)
+        >>> print(output)
+        [[0. 1. 2.]
+         [0. 0. 1.]]
+    """
+
+    def __init__(self, reduction='mean'):
+        """Initialize L1LossExt."""
+        super(L1LossExt, self).__init__(reduction)
+        self.reduction = reduction
+
+    def construct(self, logits, labels):
+        return l1_loss_ext_op(logits, labels, self.reduction)
+
+
 class MSELoss(LossBase):
     r"""
     Calculates the mean squared error between the predicted value and the label value.
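For the first docstring example above, `L1LossExt` with the default `reduction='mean'` reduces to a plain mean absolute error, which can be sanity-checked off-device with NumPy (an illustration, not part of the diff):

    import numpy as np

    logits = np.array([1, 2, 3], dtype=np.float32)
    labels = np.array([1, 2, 2], dtype=np.float32)

    # mean(|logits - labels|) = (0 + 0 + 1) / 3
    print(np.abs(logits - labels).mean())   # 0.33333334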
@@ -287,6 +362,7 @@ class MSELoss(LossBase):
     Raises:
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'`` or ``'sum'``.
         ValueError: If `logits` and `labels` have different shapes and cannot be broadcasted.
+        TypeError: if `logits` and `labels` have different data types.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1580,7 +1656,7 @@ class BCELoss(LossBase):
     The formula is as follow:

     .. math::
-        L = \{l_1,\dots,l_N\}^\top, \quad
+        L = \{l_1,\dots,l_n,\dots,l_N\}^\top, \quad
         l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]

     where N is the batch size. Then,
@@ -1850,14 +1926,16 @@ class BCEWithLogitsLoss(LossBase):

         weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
             If not None, it can be broadcast to a tensor with shape of `input`,
-            data type must be float16 or float32. Default: ``None`` .
+            data type must be float16, float32 or bfloat16(only Atlas A2 series products are supported).
+            Default: ``None`` .
         pos_weight (Tensor, optional): A weight of positive examples. Must be a vector with length equal to the
             number of classes. If not None, it must be broadcast to a tensor with shape of `input`, data type
-            must be float16 or float32. Default: ``None`` .
+            must be float16, float32 or bfloat16(only Atlas A2 series products are supported). Default: ``None`` .

     Inputs:
         - **input** (Tensor) - Input `input` with shape :math:`(N, *)` where :math:`*` means, any number
-          of additional dimensions. The data type must be float16 or float32.
+          of additional dimensions. The data type must be float16, float32 or bfloat16(only Atlas A2 series products
+          are supported).
         - **target** (Tensor) - Ground truth label with shape :math:`(N, *)` where :math:`*` means, any number
           of additional dimensions. The same shape and data type as `input`.

@@ -1867,9 +1945,9 @@ class BCEWithLogitsLoss(LossBase):

     Raises:
         TypeError: If input `input` or `target` is not Tensor.
-        TypeError: If data type of `input` or `target` is neither float16 nor float32.
+        TypeError: If data type of `input` or `target` is not float16, float32 or bfloat16.
         TypeError: If `weight` or `pos_weight` is a parameter.
-        TypeError: If data type of `weight` or `pos_weight` is neither float16 nor float32.
+        TypeError: If data type of `weight` or `pos_weight` is not float16 , float32 or bfloat16.
         TypeError: If data type of `reduction` is not string.
         ValueError: If `weight` or `pos_weight` can not be broadcast to a tensor with shape of `input`.
         ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
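With the widened dtype documentation, bfloat16 inputs are accepted in addition to float16 and float32, but only on Atlas A2 series products. A sketch of such a call, assuming an Atlas A2 / Ascend backend (the values are illustrative):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn

    loss = nn.BCEWithLogitsLoss()
    logits = Tensor(np.array([[-0.8, 1.2, 0.7], [-0.1, -0.4, 0.7]]), ms.bfloat16)
    target = Tensor(np.array([[0.3, 0.8, 1.2], [-0.6, 0.1, 2.2]]), ms.bfloat16)
    print(loss(logits, target))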
mindspore/nn/optim/__init__.py CHANGED
@@ -38,7 +38,8 @@ from mindspore.nn.optim.adafactor import AdaFactor
 from mindspore.nn.optim.adasum import AdaSumByDeltaWeightWrapCell, AdaSumByGradWrapCell
 from mindspore.nn.optim.adamax import AdaMax
 from mindspore.nn.optim.adadelta import Adadelta
+from mindspore.nn.optim.tft_wrapper import OptTFTWrapper

 __all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
            'Lamb', 'SGD', 'ASGD', 'Rprop', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'thor', 'AdaFactor',
-           'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell', 'AdaMax', 'Adadelta']
+           'AdaSumByDeltaWeightWrapCell', 'AdaSumByGradWrapCell', 'AdaMax', 'Adadelta', 'OptTFTWrapper']
mindspore/nn/optim/adadelta.py CHANGED
@@ -55,7 +55,7 @@ class Adadelta(Optimizer):
             w_{t} = w_{t-1} - \gamma * update_{t}
         \end{array}

-    where :math:`g` represents `grads`, :math:`\gamma` represents `learning_rate`, :math:`p` represents `rho`,
+    where :math:`g` represents `grads`, :math:`\gamma` represents `learning_rate`, :math:`\rho` represents `rho`,
     :math:`\epsilon` represents `epsilon`, :math:`w` represents `params`,
     :math:`accum` represents accumulation, :math:`accum\_update` represents accumulation update,
     :math:`t` represents current step.
mindspore/nn/optim/adam.py CHANGED
@@ -906,7 +906,7 @@ class AdamWeightDecay(Optimizer):
     There is usually no connection between a optimizer and mixed precision. But when `FixedLossScaleManager` is used
     and `drop_overflow_update` in `FixedLossScaleManager` is set to False, optimizer needs to set the 'loss_scale'.
     As this optimizer has no argument of `loss_scale`, so `loss_scale` needs to be processed by other means, refer
-    document `LossScale <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ to
+    document `LossScale <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ to
     process `loss_scale` correctly.

     If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
mindspore/nn/optim/lamb.py CHANGED
@@ -132,7 +132,7 @@ class Lamb(Optimizer):
     There is usually no connection between a optimizer and mixed precision. But when `FixedLossScaleManager` is used
     and `drop_overflow_update` in `FixedLossScaleManager` is set to False, optimizer needs to set the 'loss_scale'.
     As this optimizer has no argument of `loss_scale`, so `loss_scale` needs to be processed by other means. Refer
-    document `LossScale <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ to
+    document `LossScale <https://www.mindspore.cn/tutorials/en/master/beginner/mixed_precision.html>`_ to
     process `loss_scale` correctly.

     If parameters are not grouped, the `weight_decay` in optimizer will be applied on the network parameters without
mindspore/nn/optim/tft_wrapper.py ADDED
@@ -0,0 +1,124 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""OptTFTWrapper"""
+from __future__ import absolute_import
+
+import os
+from mindspore.common.tensor import Tensor
+from mindspore.nn.optim.optimizer import Optimizer
+from mindspore.ops.operations.manually_defined._inner import TensorReport
+from mindspore import ops, context
+from mindspore.common.parameter import Parameter
+import mindspore.common.dtype as mstype
+
+class OptTFTWrapper(Optimizer):
+    r"""
+    Implements TFT optimizer wrapper, this wrapper is used to report status to MindIO TFT before optimizer updating.
+
+    Note:
+        This optimizer is depend on MindIO TFT feature. Currently only support ascend graph mode and
+        sink_size must be less than 1.
+
+    Args:
+        opt (Optimizer): Must be sub-class of Optimizer.
+
+    Inputs:
+        - **gradients** (tuple[Tensor]) - The gradients of opt's `params`, the shape is the same as opt's `params`.
+
+    Outputs:
+        Tensor, result of executing optimizer 'opt'.
+
+    Raises:
+        TypeError: If the parameter opt is not an subclass of Optimizer.
+        ValueError: If the platform is not Ascend graph mode, or customer doesn't switch on TFT feature.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import nn
+        >>>
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> #1) All parameters use the same learning rate and weight decay
+        >>> optim = nn.SGD(params=net.trainable_params())
+        >>> optim_wrapper = nn.OptTFTWrapper(optim)
+        >>>
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> model = ms.train.Model(net, loss_fn=loss, optimizer=optim)
+    """
+
+    def __init__(self, opt, **kwargs):
+        if not isinstance(opt, Optimizer):
+            raise TypeError(f"For 'OptTFTWrapper', the argument 'opt' must be Optimizer type, " f"but got {type(opt)}.")
+        super(OptTFTWrapper, self).__init__(opt.learning_rate, opt._parameters)  # pylint: disable=W0212
+        tft_env = os.getenv("MS_ENABLE_TFT", "")
+        if ("TTP:1" not in tft_env) and ("UCE:1" not in tft_env):
+            raise ValueError("MindIO TFT regitster need custom switch on[MS_ENABLE_TFT='{TTP:1,UCE:1}']!")
+        mode = context.get_context("mode")
+        device_target = context.get_context("device_target")
+        if device_target != "Ascend" or mode != context.GRAPH_MODE:
+            raise ValueError("MindIO adataper only support on Ascend device with GRAPH Mode!")
+        self.opt = opt
+        self.report = TensorReport()
+        self.depend = ops.Depend()
+        self.allreduce_sum = ops.AllReduce()
+        self.allreduce_sum.add_prim_attr("tft_report_before", True)
+        self.tft_g_one_flag = Parameter(Tensor([1], dtype=mstype.int32))
+
+        self.param_rank = opt.param_rank
+        self.optim_filter = opt.optim_filter
+        self.loss_scale = opt.loss_scale
+        self.dynamic_weight_decay = opt.dynamic_weight_decay
+        self.grad_centralization = opt.grad_centralization
+
+        self.dynamic_lr = opt.dynamic_lr
+        self.global_step = opt.global_step
+        self.is_group = opt.is_group
+        self.is_group_lr = opt.is_group_lr
+        self.is_group_params_ordered = opt.is_group_params_ordered
+        self.use_parallel = opt.use_parallel
+        if self.is_group:
+            self.group_params = opt.group_params
+            self.group_lr = opt.group_lr
+            self.group_weight_decay = opt.group_weight_decay
+            self.group_grad_centralization = opt.group_grad_centralization
+            self.grad_centralization_flags = opt.grad_centralization_flags
+
+        self.skip_auto_parallel_compile = opt.skip_auto_parallel_compile
+
+        self.learning_rate = opt.learning_rate
+        self.parameters = opt.parameters
+        self.decay_flags = opt.decay_flags
+        self.dynamic_decay_flags = opt.dynamic_decay_flags
+        self.weight_decay = opt.weight_decay
+        self.exec_weight_decay = opt.exec_weight_decay
+        self.ps_parameters = opt.ps_parameters
+        self.cache_enable = opt.cache_enable
+        self.reciprocal_scale = opt.reciprocal_scale
+        self.need_scale = opt.need_scale
+        self.global_step_increase_tensor = opt.global_step_increase_tensor
+        self.param_length = opt.param_length
+        self.enable_tuple_broaden = opt.enable_tuple_broaden
+
+    def construct(self, gradients):
+        tft_g_one_flag = self.depend(self.tft_g_one_flag, gradients)
+        self.tft_g_one_flag = self.allreduce_sum(tft_g_one_flag)
+
+        grads = self.depend(gradients, self.report("tft_report", self.tft_g_one_flag))
+        opt_ret = self.opt(grads)
+        return opt_ret
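For context on how the new wrapper is meant to be used: `__init__` refuses to construct unless the MindIO TFT switches are present in `MS_ENABLE_TFT` and the runtime is Ascend graph mode, so a typical setup looks roughly like the sketch below. `LeNet5` follows the docstring's own example; the environment value and the choice to hand the wrapper (rather than the inner optimizer) to `Model` are assumptions drawn from the checks in `__init__` and the report logic in `construct`:

    import os
    # Must be set before OptTFTWrapper is constructed, otherwise __init__ raises ValueError.
    os.environ["MS_ENABLE_TFT"] = "{TTP:1,UCE:1}"

    import mindspore as ms
    from mindspore import nn

    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")  # wrapper requires Ascend + GRAPH_MODE

    net = LeNet5()                                    # as in the docstring example above
    optim = nn.SGD(params=net.trainable_params())
    optim_wrapper = nn.OptTFTWrapper(optim)           # reports to MindIO TFT before each update

    loss = nn.SoftmaxCrossEntropyWithLogits()
    model = ms.train.Model(net, loss_fn=loss, optimizer=optim_wrapper)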