mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package's registry page for more details.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
@@ -1,181 +0,0 @@
1
- # Copyright 2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Define the grad rules of linalg related operations."""
17
- from __future__ import absolute_import
18
-
19
- import mindspore
20
-
21
- from mindspore.ops import Tensor
22
- from mindspore.ops import functional as F
23
- from mindspore.ops import operations as P
24
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
25
- from mindspore.ops.operations import math_ops as math
26
- from mindspore.ops.operations import linalg_ops as linalg
27
- from mindspore.ops.operations import array_ops as arrays
28
- from mindspore.ops.primitive import constexpr, _primexpr
29
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
30
-
31
- _shape = arrays.Shape()
32
-
33
- _dtype = arrays.DType()
34
- _cast = arrays.Cast()
35
- _transpose = arrays.Transpose()
36
-
37
- _conj = math.Conj()
38
- _reciprocal = math.Reciprocal()
39
-
40
- _k_0 = Tensor(0, mindspore.int32)
41
-
42
-
43
- @_primexpr
44
- def _check_dim(dim):
45
- if dim < 2:
46
- raise ValueError(f"The dim can not be less than 2, which is {dim}.")
47
-
48
-
49
- @_primexpr
50
- def generate_perm_for_matrix_transpose(input_dim):
51
- perm = tuple(range(input_dim - 2))
52
- perm = perm + (input_dim - 1, input_dim - 2)
53
- return perm
54
-
55
-
56
- def _matrix_transpose(a):
57
- """Transpose last two axes"""
58
- a_shape = _shape(a)
59
- if F.is_sequence_value_unknown(a_shape):
60
- dim = P.Rank()(a)
61
- perm = P.Range()(P.Cast()(0, mindspore.int64), P.Cast()(dim, mindspore.int64), P.Cast()(1, mindspore.int64))
62
- perm = P.Concat(axis=-1)((perm[:-2], perm[-1:], perm[-2:-1]))
63
- else:
64
- dim = P.Rank()(a)
65
- _check_dim(dim)
66
- perm = generate_perm_for_matrix_transpose(dim)
67
- return _transpose(a, perm)
68
-
69
-
70
- def _adjoint(a):
71
- return _matrix_transpose(_conj(a))
72
-
73
-
74
- def _safe_reciprocal(x, epsilon=1e-20):
75
- return x * _reciprocal(x * x + epsilon)
76
-
77
-
78
- @constexpr
79
- def _make_tensor(value, dtype):
80
- return Tensor(value, dtype)
81
-
82
-
83
- def _matrix_diag(diagonal):
84
- """Do matrix diagnoal"""
85
- diagonal_shape = _shape(diagonal)
86
- if F.is_sequence_value_unknown(diagonal_shape):
87
- row = P.Cast()(diagonal_shape[-1], mindspore.int32)
88
- return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, P.Cast()(0, _dtype(diagonal)))
89
-
90
- row = _make_tensor(diagonal_shape[-1], mindspore.int32)
91
- return arrays.MatrixDiagV3()(diagonal, _k_0, row, row, _make_tensor(0, _dtype(diagonal)))
92
-
93
-
94
- def _mat_mul(x, y):
95
- """Do matmul"""
96
- tensor_rank = P.Rank()(x)
97
- if tensor_rank > 2:
98
- return math.BatchMatMul()(x, y)
99
- return math.MatMul()(x, y)
100
-
101
-
102
- @bprop_getters.register(linalg.Svd)
103
- def get_bprop_svd(self):
104
- """Generate bprop for Svd"""
105
- full_matrices = self.full_matrices
106
- compute_uv = self.compute_uv
107
-
108
- svd = linalg.Svd(compute_uv=True)
109
- square = math.Square()
110
- matrix_set_diag = arrays.MatrixSetDiagV3()
111
- expand_dims = arrays.ExpandDims()
112
-
113
- def bprop(a, out, dout):
114
- if not compute_uv:
115
- s, u, v = svd(a)
116
- da = _mat_mul(u, _mat_mul(_matrix_diag(dout[0]), _adjoint(v)))
117
- return (da,)
118
-
119
- a_shape = _shape(a)
120
- tensor_rank = P.Rank()(a)
121
- _check_dim(tensor_rank)
122
- m = a_shape[-2]
123
- n = a_shape[-1]
124
- s, u, v = out
125
- ds, du, dv = dout
126
- use_adjoint = False
127
- if m > n:
128
- use_adjoint = True
129
- m, n = n, m
130
- u, v = v, u
131
- du, dv = dv, du
132
-
133
- if full_matrices and abs(m - n) > 1:
134
- raise ValueError("For 'Svd' gradient, not support for abs(m - n) > 1 with full_matrices is True.")
135
-
136
- s_mat = _matrix_diag(s)
137
- s2 = square(s)
138
-
139
- f = matrix_set_diag(_safe_reciprocal(expand_dims(s2, -2) - expand_dims(s2, -1)), zeros_like(s), _k_0)
140
- s_inv_mat = _matrix_diag(_safe_reciprocal(s))
141
-
142
- v1 = v[..., :, :m]
143
- dv1 = dv[..., :, :m]
144
-
145
- u_gu = _mat_mul(_adjoint(u), du)
146
- v_gv = _mat_mul(_adjoint(v1), dv1)
147
-
148
- f_u = f * u_gu
149
- f_v = f * v_gv
150
- ds_mat = _matrix_diag(_cast(ds, _dtype(a)))
151
- term1_nouv = (ds_mat + _mat_mul(f_u + _adjoint(f_u), s_mat) + _mat_mul(s_mat, f_v + _adjoint(f_v)))
152
-
153
- term1 = _mat_mul(u, _mat_mul(term1_nouv, _adjoint(v1)))
154
-
155
- if m == n:
156
- da_before_transpose = term1
157
- else:
158
- gv1t = _matrix_transpose(dv1)
159
- gv1t_v1 = _mat_mul(gv1t, v1)
160
- term2_nous = gv1t - _mat_mul(gv1t_v1, _adjoint(v1))
161
-
162
- if full_matrices:
163
- v2 = v[..., :, m:n]
164
- d_v2 = dv[..., :, m:n]
165
-
166
- v1t_gv2 = _mat_mul(_adjoint(v1), d_v2)
167
- term2_nous -= _mat_mul(v1t_gv2, _adjoint(v2))
168
-
169
- u_s_inv = _mat_mul(u, s_inv_mat)
170
- term2 = _mat_mul(u_s_inv, term2_nous)
171
-
172
- da_before_transpose = term1 + term2
173
-
174
- if use_adjoint:
175
- da = _matrix_transpose(da_before_transpose)
176
- else:
177
- da = da_before_transpose
178
-
179
- return (da,)
180
-
181
- return bprop
@@ -1,72 +0,0 @@
1
- # Copyright 2020-2022 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Generate bprop for other ops"""
17
-
18
- from mindspore.ops import operations as P
19
- from mindspore.ops.operations import _grad_ops as G
20
- from mindspore.ops.operations import _inner_ops as inner
21
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
22
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
23
-
24
- # Unused parameters are placeholders.
25
-
26
-
27
- @bprop_getters.register(P.InvertPermutation)
28
- def get_bprop_invert_permutation(self):
29
- """Generate bprop for InvertPermutation"""
30
-
31
- def bprop(x, out, dout):
32
- return (zeros_like(x),)
33
- return bprop
34
-
35
-
36
- @bprop_getters.register(inner.SyncBatchNorm)
37
- def get_bprop_sync_batch_norm(self):
38
- """Grad definition for `SyncBatchNorm` operation."""
39
- input_grad = G.SyncBatchNormGrad(self.epsilon, self.group, self.device_num)
40
-
41
- def bprop(x, scale, b, mean, variance, out, dout):
42
- saved_mean = out[3]
43
- saved_variance = out[4]
44
- out = input_grad(dout[0], x, scale, saved_mean, saved_variance)
45
- dx = out[0]
46
- dscale = out[1]
47
- dbias = out[2]
48
- res = (dx, dscale, dbias, zeros_like(mean), zeros_like(variance))
49
- return res
50
- return bprop
51
-
52
-
53
- @bprop_getters.register(inner.GpuConvertToDynamicShape)
54
- def get_bprop_gpu_convert_to_dynamic_shape(self):
55
- """Get backprop for GpuConvertToDynamicShape."""
56
-
57
- def bprop(x, out, dout):
58
- return (dout,)
59
- return bprop
60
-
61
-
62
- @bprop_getters.register(P._DynamicLossScale) # pylint: disable=W0212
63
- def get_bprop_dynamic_loss_scale(self):
64
- """Get backprop for dynamic_loss_scale."""
65
- mul = P.Mul()
66
- mul.add_prim_attr('split_overflow', True)
67
- mul.add_prim_attr('layer_overflow', self.layer)
68
-
69
- def bprop(x, loss_scale, out, dout):
70
- res = mul(dout, loss_scale)
71
- return res, zeros_like(loss_scale)
72
- return bprop
@@ -1,112 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """Generate bprop for quantization aware ops"""
17
-
18
- from mindspore.ops.operations import _scalar_ops
19
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
20
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
21
-
22
-
23
- @bprop_getters.register(_scalar_ops.ScalarAdd)
24
- def get_bprop_scalar_add(self):
25
- """Grad definition for `ScalarAdd` operation."""
26
-
27
- def bprop(x, y, out, dout):
28
- return dout, dout
29
-
30
- return bprop
31
-
32
-
33
- @bprop_getters.register(_scalar_ops.ScalarSub)
34
- def get_bprop_scalar_sub(self):
35
- """Grad definition for `ScalarSub` operation."""
36
-
37
- def bprop(x, y, out, dout):
38
- return dout, 0 - dout
39
-
40
- return bprop
41
-
42
-
43
- @bprop_getters.register(_scalar_ops.ScalarMul)
44
- def get_bprop_scalar_mul(self):
45
- """Grad definition for `ScalarMul` operation."""
46
-
47
- def bprop(x, y, out, dout):
48
- bc_dx = y * dout
49
- bc_dy = x * dout
50
- return bc_dx, bc_dy
51
-
52
- return bprop
53
-
54
-
55
- @bprop_getters.register(_scalar_ops.ScalarDiv)
56
- def get_bprop_scalar_div(self):
57
- """Grad definition for `ScalarDiv` operation."""
58
-
59
- def bprop(x, y, out, dout):
60
- bc_dx = dout / y
61
- bc_dy = 0 - bc_dx * out
62
- return bc_dx, bc_dy
63
-
64
- return bprop
65
-
66
-
67
- @bprop_getters.register(_scalar_ops.ScalarFloordiv)
68
- def get_bprop_scalar_floordiv(self):
69
- """Grad definition for `ScalarFloorDiv` operation."""
70
-
71
- def bprop(x, y, out, dout):
72
- return zeros_like(x), zeros_like(y)
73
-
74
- return bprop
75
-
76
-
77
- @bprop_getters.register(_scalar_ops.ScalarMod)
78
- def get_bprop_scalar_mod(self):
79
- """Grad definition for `ScalarMod` operation."""
80
-
81
- def bprop(x, y, out, dout):
82
- bc_dx = dout
83
- bc_dy = -dout * (x // y)
84
- return bc_dx, bc_dy
85
-
86
- return bprop
87
-
88
-
89
- @bprop_getters.register(_scalar_ops.scalar_eq)
90
- @bprop_getters.register(_scalar_ops.scalar_le)
91
- @bprop_getters.register(_scalar_ops.scalar_lt)
92
- @bprop_getters.register(_scalar_ops.scalar_ge)
93
- @bprop_getters.register(_scalar_ops.scalar_gt)
94
- @bprop_getters.register(_scalar_ops.bit_and)
95
- @bprop_getters.register(_scalar_ops.bit_or)
96
- def get_bprop_scalar_logic(self):
97
- """Grad definition for `ScalarLogicOps` operation."""
98
-
99
- def bprop(x, y, out, dout):
100
- return zeros_like(x), zeros_like(y)
101
-
102
- return bprop
103
-
104
-
105
- @bprop_getters.register(_scalar_ops.ScalarBool)
106
- def get_bprop_scalar_bool(self):
107
- """Grad definition for `ScalarBool` operation."""
108
-
109
- def bprop(x, out, dout):
110
- return zeros_like(x)
111
-
112
- return bprop
@@ -1,351 +0,0 @@
1
- # Copyright 2023 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
-
16
- """grad_sequence_ops"""
17
-
18
- from mindspore.ops.operations import _sequence_ops as seq
19
- from mindspore.ops import operations as P
20
- from mindspore.ops import functional as F
21
- from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
22
- from mindspore.ops._grad_experimental.grad_base import bprop_getters
23
- from mindspore.ops.primitive import Primitive
24
-
25
-
26
- tuple_setitem = Primitive("tuple_setitem")
27
- list_setitem = Primitive("list_setitem")
28
-
29
-
30
- @bprop_getters.register(seq.SequenceCount)
31
- def get_bprop_count(self):
32
- """Generate bprop for SequenceCount"""
33
-
34
- def bprop(x, y, out, dout):
35
- return (zeros_like(x), zeros_like(y))
36
-
37
- return bprop
38
-
39
-
40
- @bprop_getters.register(seq.sequence_len)
41
- def get_bprop_sequence_len(self):
42
- """Generate bprop for sequence_len"""
43
- def bprop(x, out, dout):
44
- return (zeros_like(x),)
45
-
46
- return bprop
47
-
48
-
49
- @bprop_getters.register(seq.SequenceAdd)
50
- def get_bprop_sequence_add(self):
51
- """Generate bprop for SequenceAdd"""
52
- def bprop(x, y, out, dout):
53
- out_offset = seq.SequenceAddOffset()(x, y)
54
- dx = seq.SequenceSlice()(dout, out_offset[0], len(x), 1)
55
- dy = seq.SequenceSlice()(dout, out_offset[1], len(x) + len(y), 1)
56
-
57
- return (dx, dy)
58
-
59
- return bprop
60
-
61
-
62
- @bprop_getters.register(seq.SequenceUnstack)
63
- def get_bprop_sequence_unstack(self):
64
- """Generate bprop for SequenceUnstack"""
65
- axis = self.axis
66
-
67
- def bprop(x, out, dout):
68
- seq_unstack_grad = seq.SequenceStack(axis)
69
- out = seq_unstack_grad(dout)
70
- return (out,)
71
-
72
- return bprop
73
-
74
-
75
- @bprop_getters.register(seq.SequenceSlice)
76
- def get_bprop_slice(self):
77
- """Generate bprop for SequenceSlice"""
78
-
79
- def bprop(x, start, stop, step, out, dout):
80
- dx = seq.SequenceSliceGrad()(dout, x, start, stop, step)
81
- res = (dx, zeros_like(start), zeros_like(stop), zeros_like(step))
82
- return res
83
-
84
- return bprop
85
-
86
-
87
- @bprop_getters.register(seq.SequenceIndex)
88
- def get_bprop_index(self):
89
- """Generate bprop for SequenceIndex"""
90
-
91
- def bprop(x, y, start, end, out, dout):
92
- res = (zeros_like(x), zeros_like(y), zeros_like(start), zeros_like(end))
93
- return res
94
-
95
- return bprop
96
-
97
-
98
- @bprop_getters.register(seq.InSequence)
99
- def get_bprop_insequence(self):
100
- """Generate bprop for InSequence"""
101
-
102
- def bprop(x, y, out, dout):
103
- return (zeros_like(x), seq.SequenceZerosLike()(y))
104
-
105
- return bprop
106
-
107
-
108
- @bprop_getters.register("tuple_equal")
109
- @bprop_getters.register("list_equal")
110
- def get_bprop_seq_equal(self):
111
- """Generate bprop for tuple_equal and list_equal"""
112
-
113
- def bprop(x, y, out, dout):
114
- return (zeros_like(x), zeros_like(y))
115
-
116
- return bprop
117
-
118
-
119
- @bprop_getters.register("shape_mul")
120
- def get_bprop_shape_mul(self):
121
- """Generate bprop for tuple_equal and list_equal"""
122
-
123
- def bprop(x, out, dout):
124
- dx = seq.ShapeMulGrad()(x, dout)
125
- return (dx,)
126
-
127
- return bprop
128
-
129
-
130
- @bprop_getters.register("tuple_setitem")
131
- def get_bprop_tuple_setitem(self):
132
- """Generate bprop for TupleSetItem and ListSetItem"""
133
-
134
- def bprop(x, idx, value, out, dout):
135
- d_x = tuple_setitem(dout, idx, zeros_like(value))
136
- d_value = dout[idx]
137
- d_idx = 0
138
- return (d_x, zeros_like(d_idx), d_value)
139
-
140
- return bprop
141
-
142
-
143
- @bprop_getters.register("list_setitem")
144
- def get_bprop_list_setitem(self):
145
- """Generate bprop for TupleSetItem and ListSetItem"""
146
-
147
- def bprop(x, idx, value, out, dout):
148
- d_x = list_setitem(dout, idx, zeros_like(value))
149
- d_value = dout[idx]
150
- d_idx = 0
151
- return (d_x, zeros_like(d_idx), d_value)
152
-
153
- return bprop
154
-
155
-
156
- @bprop_getters.register("ListInplaceReverse")
157
- def get_bprop_list_inplace_reverse(self):
158
- """Generate bprop for list inplace reverse"""
159
-
160
- def bprop(x, out, dout):
161
- return (zeros_like(x),)
162
-
163
- return bprop
164
-
165
-
166
- @bprop_getters.register("ListInplaceExtend")
167
- def get_bprop_list_inplace_extend(self):
168
- """Generate bprop for list inplace extend"""
169
-
170
- def bprop(x, y, out, dout):
171
- return (zeros_like(x), zeros_like(y))
172
-
173
- return bprop
174
-
175
-
176
- @bprop_getters.register("ListInplaceInsert")
177
- def get_bprop_list_inplace_insert(self):
178
- """Generate bprop for list inplace insert"""
179
-
180
- def bprop(x, index, target, out, dout):
181
- return (zeros_like(x), zeros_like(index), zeros_like(target))
182
-
183
- return bprop
184
-
185
-
186
- @bprop_getters.register("ListInplacePop")
187
- def get_bprop_list_inplace_pop(self):
188
- """Generate bprop for list inplace pop"""
189
-
190
- def bprop(x, index, out, dout):
191
- return (zeros_like(x), zeros_like(index))
192
-
193
- return bprop
194
-
195
-
196
- @bprop_getters.register(seq.ListAppend)
197
- def get_bprop_list_append(self):
198
- """Generate bprop for ListAppend"""
199
-
200
- def bprop(x, value, out, dout):
201
- d_x = seq.ListAppendAndInsertGrad()(dout, -1)
202
- return (d_x, zeros_like(value))
203
-
204
- return bprop
205
-
206
-
207
- @bprop_getters.register(seq.ListInsert)
208
- def get_bprop_list_insert(self):
209
- """Generate bprop for ListInsert"""
210
-
211
- def bprop(x, idx, value, out, dout):
212
- d_x = seq.ListAppendAndInsertGrad()(dout, idx)
213
- return (d_x, zeros_like(idx), zeros_like(value))
214
-
215
- return bprop
216
-
217
-
218
- @bprop_getters.register(seq.TupleToTensor)
219
- def get_bprop_tuple_to_tensor(self):
220
- """Generate bprop for TupleToTensor"""
221
-
222
- def bprop(x, dtype, out, dout):
223
- tuple_type = F.typeof(x)
224
- dout = P.Cast()(dout, tuple_type)
225
- d_x = seq.TensorToTuple()(dout)
226
- return (d_x, zeros_like(dtype))
227
-
228
- return bprop
229
-
230
-
231
- @bprop_getters.register(seq.ListToTensor)
232
- def get_bprop_list_to_tensor(self):
233
- """Generate bprop for ListToTensor"""
234
-
235
- def bprop(x, dtype, out, dout):
236
- tuple_type = F.typeof(x)
237
- dout = P.Cast()(dout, tuple_type)
238
- d_x = seq.TensorToList()(dout)
239
- return (d_x, zeros_like(dtype))
240
-
241
- return bprop
242
-
243
-
244
- @bprop_getters.register(P.ScalarToTensor)
245
- def get_bprop_scalar_to_tensor(self):
246
- """Generate bprop for ScalarToTensor"""
247
-
248
- def bprop(x, dtype, out, dout):
249
- scalar_type = F.typeof(x)
250
- dout = P.Cast()(dout, scalar_type)
251
- d_x = seq.TensorToScalar()(dout)
252
- return (d_x, zeros_like(dtype))
253
-
254
- return bprop
255
-
256
-
257
- @bprop_getters.register(seq.TensorToTuple)
258
- def get_bprop_tensor_to_tuple(self):
259
- """Generate bprop for TensorToTuple"""
260
-
261
- def bprop(x, out, dout):
262
- dtype = F.typeof(x)
263
- d_x = seq.TupleToTensor()(dout, dtype)
264
- return (d_x,)
265
-
266
- return bprop
267
-
268
-
269
- @bprop_getters.register(seq.TensorToList)
270
- def get_bprop_tensor_to_list(self):
271
- """Generate bprop for TensorToList"""
272
-
273
- def bprop(x, out, dout):
274
- dtype = F.typeof(x)
275
- d_x = seq.ListToTensor()(dout, dtype)
276
- return (d_x,)
277
-
278
- return bprop
279
-
280
-
281
- @bprop_getters.register(seq.TensorToScalar)
282
- def get_bprop_tensor_to_scalar(self):
283
- """Generate bprop for TensorToScalar"""
284
-
285
- def bprop(x, out, dout):
286
- dtype = F.typeof(x)
287
- d_x = P.ScalarToTensor()(dout, dtype)
288
- return (d_x,)
289
-
290
- return bprop
291
-
292
-
293
- @bprop_getters.register("tuple_le")
294
- @bprop_getters.register("tuple_lt")
295
- @bprop_getters.register("list_le")
296
- @bprop_getters.register("list_lt")
297
- def get_bprop_less(self):
298
- """Generate bprop for SequenceLessThan and SequenceLessEqual"""
299
-
300
- def bprop(x, y, out, dout):
301
- return (zeros_like(x), zeros_like(y))
302
-
303
- return bprop
304
-
305
-
306
- @bprop_getters.register(seq.SequenceMul)
307
- def get_bprop_mul(self):
308
- """Generate bprop for SequenceMul"""
309
-
310
- def bprop(x, y, out, dout):
311
- dx = x
312
- if isinstance(x, tuple):
313
- for i in range(len(x)):
314
- dx = tuple_setitem(dx, i, dout[i])
315
- else:
316
- for i in range(len(x)):
317
- dx = list_setitem(dx, i, dout[i])
318
- return (dx, zeros_like(y))
319
-
320
- return bprop
321
-
322
-
323
- @bprop_getters.register(seq.SequenceMin)
324
- @bprop_getters.register(seq.SequenceMax)
325
- def get_bprop_max_min(self):
326
- """Generate bprop for SequenceMax and SequenceMax"""
327
-
328
- def bprop(x, out, dout):
329
- index = x.index(out)
330
- if isinstance(x, tuple):
331
- dx = tuple_setitem(zeros_like(x), index, dout)
332
- else:
333
- dx = list_setitem(zeros_like(x), index, dout)
334
- return (dx,)
335
-
336
- return bprop
337
-
338
-
339
- @bprop_getters.register("tuple_greater_than")
340
- @bprop_getters.register("list_greater_than")
341
- @bprop_getters.register("tuple_greater_equal")
342
- @bprop_getters.register("list_greater_equal")
343
- def get_bprop_greater(self):
344
- """Generate bprop for tuple_greater_than, list_greater_than,
345
- tuple_greater_equal, list_greater_equal.
346
- """
347
-
348
- def bprop(x, y, out, dout):
349
- return (zeros_like(x), zeros_like(y))
350
-
351
- return bprop