mindspore 2.1.0__cp38-none-any.whl → 2.2.0__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (539)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +49 -16
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  20. mindspore/_akg/akg/utils/kernel_exec.py +58 -260
  21. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  22. mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
  23. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  24. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  25. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  26. mindspore/_check_jit_forbidden_api.py +3 -1
  27. mindspore/_checkparam.py +26 -32
  28. mindspore/_extends/graph_kernel/__init__.py +0 -1
  29. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  30. mindspore/_extends/graph_kernel/splitter.py +1 -9
  31. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  32. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
  33. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  34. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  35. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +4 -4
  36. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  37. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  38. mindspore/_extends/parse/__init__.py +12 -15
  39. mindspore/_extends/parse/namespace.py +7 -33
  40. mindspore/_extends/parse/parser.py +61 -71
  41. mindspore/_extends/parse/resources.py +1 -1
  42. mindspore/_extends/parse/standard_method.py +72 -95
  43. mindspore/_extends/parse/trope.py +1 -1
  44. mindspore/_extends/remote/kernel_build_server.py +24 -7
  45. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  46. mindspore/_install_custom.py +43 -0
  47. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  48. mindspore/amp.py +47 -11
  49. mindspore/bin/cache_admin +0 -0
  50. mindspore/bin/cache_server +0 -0
  51. mindspore/boost/boost.py +1 -8
  52. mindspore/boost/boost_cell_wrapper.py +3 -2
  53. mindspore/boost/grad_accumulation.py +1 -1
  54. mindspore/boost/group_loss_scale_manager.py +8 -7
  55. mindspore/common/__init__.py +5 -3
  56. mindspore/common/_jit_fallback_utils.py +6 -0
  57. mindspore/common/_register_for_adapter.py +2 -0
  58. mindspore/common/_register_for_tensor.py +2 -2
  59. mindspore/common/_stub_tensor.py +13 -0
  60. mindspore/common/_utils.py +13 -0
  61. mindspore/common/api.py +173 -258
  62. mindspore/common/auto_dynamic_shape.py +498 -0
  63. mindspore/common/dtype.py +18 -11
  64. mindspore/common/dump.py +6 -4
  65. mindspore/common/initializer.py +14 -14
  66. mindspore/common/jit_config.py +33 -15
  67. mindspore/common/lazy_inline.py +126 -7
  68. mindspore/common/mindir_util.py +101 -0
  69. mindspore/common/parameter.py +51 -41
  70. mindspore/common/seed.py +4 -4
  71. mindspore/common/sparse_tensor.py +13 -14
  72. mindspore/common/tensor.py +240 -145
  73. mindspore/communication/__init__.py +7 -4
  74. mindspore/communication/_comm_helper.py +83 -4
  75. mindspore/communication/management.py +152 -84
  76. mindspore/config/op_info.config +13 -2
  77. mindspore/config/super_bar_config.json +4 -2
  78. mindspore/context.py +143 -59
  79. mindspore/dataset/__init__.py +5 -5
  80. mindspore/dataset/audio/__init__.py +2 -2
  81. mindspore/dataset/audio/transforms.py +52 -52
  82. mindspore/dataset/callback/ds_callback.py +16 -2
  83. mindspore/dataset/core/config.py +68 -51
  84. mindspore/dataset/engine/cache_client.py +28 -5
  85. mindspore/dataset/engine/datasets.py +250 -112
  86. mindspore/dataset/engine/datasets_audio.py +43 -211
  87. mindspore/dataset/engine/datasets_standard_format.py +11 -35
  88. mindspore/dataset/engine/datasets_text.py +43 -67
  89. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  90. mindspore/dataset/engine/datasets_vision.py +219 -1029
  91. mindspore/dataset/engine/iterators.py +11 -4
  92. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  93. mindspore/dataset/engine/obs/util.py +3 -0
  94. mindspore/dataset/engine/samplers.py +1 -1
  95. mindspore/dataset/engine/validators.py +19 -5
  96. mindspore/dataset/text/__init__.py +3 -3
  97. mindspore/dataset/text/transforms.py +101 -127
  98. mindspore/dataset/text/utils.py +205 -138
  99. mindspore/dataset/transforms/__init__.py +1 -1
  100. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  101. mindspore/dataset/transforms/transforms.py +95 -40
  102. mindspore/dataset/utils/browse_dataset.py +8 -2
  103. mindspore/dataset/utils/line_reader.py +17 -19
  104. mindspore/dataset/vision/__init__.py +3 -3
  105. mindspore/dataset/vision/c_transforms.py +6 -3
  106. mindspore/dataset/vision/transforms.py +409 -287
  107. mindspore/dataset/vision/utils.py +13 -14
  108. mindspore/dataset/vision/validators.py +11 -1
  109. mindspore/experimental/map_parameter.py +14 -0
  110. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  111. mindspore/{nn/optim_ex → experimental/optim}/adam.py +59 -66
  112. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  113. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  114. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  115. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  116. mindspore/gen_ops.py +273 -0
  117. mindspore/include/OWNERS +0 -1
  118. mindspore/include/api/data_type.h +2 -1
  119. mindspore/include/api/graph.h +0 -15
  120. mindspore/include/api/kernel.h +2 -0
  121. mindspore/include/api/kernel_api.h +37 -12
  122. mindspore/include/api/model.h +0 -14
  123. mindspore/include/api/types.h +37 -4
  124. mindspore/include/c_api/ms/abstract.h +67 -0
  125. mindspore/include/c_api/ms/attribute.h +197 -0
  126. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  127. mindspore/include/c_api/ms/base/macros.h +32 -0
  128. mindspore/include/c_api/ms/base/status.h +33 -0
  129. mindspore/include/c_api/ms/base/types.h +282 -0
  130. mindspore/include/c_api/ms/context.h +102 -0
  131. mindspore/include/c_api/ms/graph.h +160 -0
  132. mindspore/include/c_api/ms/node.h +606 -0
  133. mindspore/include/c_api/ms/tensor.h +161 -0
  134. mindspore/include/c_api/ms/value.h +84 -0
  135. mindspore/include/dataset/constants.h +6 -5
  136. mindspore/include/dataset/execute.h +23 -13
  137. mindspore/include/dataset/text.h +26 -26
  138. mindspore/include/dataset/transforms.h +13 -13
  139. mindspore/include/dataset/vision.h +60 -60
  140. mindspore/include/dataset/vision_ascend.h +5 -6
  141. mindspore/include/dataset/vision_lite.h +17 -17
  142. mindspore/include/mindapi/base/type_id.h +1 -0
  143. mindspore/include/mindapi/base/types.h +1 -0
  144. mindspore/lib/libdnnl.so.2 +0 -0
  145. mindspore/lib/libjemalloc.so.2 +0 -0
  146. mindspore/lib/libmindspore.so +0 -0
  147. mindspore/lib/libmindspore_backend.so +0 -0
  148. mindspore/lib/libmindspore_common.so +0 -0
  149. mindspore/lib/libmindspore_core.so +0 -0
  150. mindspore/lib/libmindspore_glog.so.0 +0 -0
  151. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  152. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  153. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  154. mindspore/lib/libmindspore_shared_lib.so +0 -0
  155. mindspore/lib/libnnacl.so +0 -0
  156. mindspore/lib/libopencv_core.so.4.5 +0 -0
  157. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  158. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  159. mindspore/lib/libps_cache.so +0 -0
  160. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  161. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  162. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
  163. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  164. mindspore/lib/plugin/ascend/libakg.so +0 -0
  165. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  166. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  167. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  168. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  169. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  170. mindspore/lib/plugin/cpu/libakg.so +0 -0
  171. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  172. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  173. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  174. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  175. mindspore/nn/__init__.py +0 -2
  176. mindspore/nn/cell.py +316 -74
  177. mindspore/nn/dynamic_lr.py +21 -21
  178. mindspore/nn/layer/activation.py +21 -28
  179. mindspore/nn/layer/basic.py +15 -13
  180. mindspore/nn/layer/channel_shuffle.py +1 -1
  181. mindspore/nn/layer/container.py +271 -9
  182. mindspore/nn/layer/conv.py +310 -207
  183. mindspore/nn/layer/dense.py +8 -5
  184. mindspore/nn/layer/embedding.py +33 -27
  185. mindspore/nn/layer/flash_attention.py +82 -41
  186. mindspore/nn/layer/image.py +8 -6
  187. mindspore/nn/layer/math.py +13 -18
  188. mindspore/nn/layer/normalization.py +107 -66
  189. mindspore/nn/layer/padding.py +1 -1
  190. mindspore/nn/layer/pooling.py +131 -109
  191. mindspore/nn/layer/rnn_cells.py +22 -17
  192. mindspore/nn/layer/rnns.py +13 -16
  193. mindspore/nn/layer/thor_layer.py +1 -1
  194. mindspore/nn/layer/transformer.py +221 -154
  195. mindspore/nn/learning_rate_schedule.py +9 -1
  196. mindspore/nn/loss/loss.py +235 -174
  197. mindspore/nn/optim/ada_grad.py +2 -1
  198. mindspore/nn/optim/adadelta.py +1 -0
  199. mindspore/nn/optim/adafactor.py +2 -1
  200. mindspore/nn/optim/adam.py +7 -4
  201. mindspore/nn/optim/adamax.py +3 -2
  202. mindspore/nn/optim/adasum.py +2 -2
  203. mindspore/nn/optim/asgd.py +2 -3
  204. mindspore/nn/optim/ftrl.py +6 -5
  205. mindspore/nn/optim/lamb.py +7 -4
  206. mindspore/nn/optim/lars.py +1 -1
  207. mindspore/nn/optim/lazyadam.py +5 -3
  208. mindspore/nn/optim/momentum.py +2 -1
  209. mindspore/nn/optim/optimizer.py +53 -4
  210. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  211. mindspore/nn/optim/rmsprop.py +4 -3
  212. mindspore/nn/optim/rprop.py +23 -12
  213. mindspore/nn/optim/sgd.py +26 -11
  214. mindspore/nn/optim/thor.py +9 -7
  215. mindspore/nn/probability/bijector/bijector.py +5 -5
  216. mindspore/nn/probability/bijector/power_transform.py +27 -27
  217. mindspore/nn/probability/bijector/softplus.py +3 -3
  218. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  219. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  220. mindspore/nn/probability/distribution/beta.py +3 -3
  221. mindspore/nn/probability/distribution/categorical.py +7 -7
  222. mindspore/nn/probability/distribution/cauchy.py +0 -1
  223. mindspore/nn/probability/distribution/distribution.py +3 -3
  224. mindspore/nn/probability/distribution/gamma.py +3 -3
  225. mindspore/nn/probability/distribution/geometric.py +4 -4
  226. mindspore/nn/probability/distribution/gumbel.py +4 -4
  227. mindspore/nn/probability/distribution/log_normal.py +2 -2
  228. mindspore/nn/probability/distribution/logistic.py +2 -2
  229. mindspore/nn/probability/distribution/poisson.py +4 -4
  230. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  231. mindspore/nn/probability/distribution/uniform.py +6 -6
  232. mindspore/nn/wrap/cell_wrapper.py +78 -34
  233. mindspore/nn/wrap/grad_reducer.py +8 -5
  234. mindspore/nn/wrap/loss_scale.py +105 -42
  235. mindspore/numpy/array_creations.py +1 -2
  236. mindspore/numpy/array_ops.py +3 -2
  237. mindspore/offline_debug/convert_async.py +2 -2
  238. mindspore/ops/_grad_experimental/__init__.py +0 -5
  239. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -2
  240. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  241. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  242. mindspore/ops/_grad_experimental/grad_implementations.py +10 -0
  243. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  244. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -181
  245. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  246. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  247. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  248. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  249. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  250. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  251. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  252. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  253. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  254. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  255. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  256. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  257. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  258. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  259. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  260. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  261. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  262. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  263. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  264. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  265. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  266. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  267. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  268. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  269. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  270. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  271. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  272. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  273. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  274. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  275. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  276. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  277. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  278. mindspore/ops/_primitive_cache.py +1 -1
  279. mindspore/ops/_tracefunc.py +45 -13
  280. mindspore/ops/_utils/utils.py +4 -1
  281. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  282. mindspore/ops/_vmap/vmap_base.py +3 -3
  283. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  284. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  285. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  286. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  287. mindspore/ops/arg_dtype_cast.py +54 -0
  288. mindspore/ops/composite/base.py +37 -10
  289. mindspore/ops/composite/math_ops.py +5 -4
  290. mindspore/ops/composite/multitype_ops/_compile_utils.py +273 -72
  291. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  292. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  293. mindspore/ops/composite/multitype_ops/getitem_impl.py +40 -2
  294. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  295. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  296. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  297. mindspore/ops/deprecated.py +304 -0
  298. mindspore/ops/function/__init__.py +4 -1
  299. mindspore/ops/function/array_func.py +167 -189
  300. mindspore/ops/function/clip_func.py +81 -13
  301. mindspore/ops/function/debug_func.py +1 -1
  302. mindspore/ops/function/grad/grad_func.py +18 -8
  303. mindspore/ops/function/image_func.py +10 -4
  304. mindspore/ops/function/linalg_func.py +5 -5
  305. mindspore/ops/function/math_func.py +575 -386
  306. mindspore/ops/function/nn_func.py +470 -251
  307. mindspore/ops/function/random_func.py +86 -56
  308. mindspore/ops/function/sparse_func.py +1 -1
  309. mindspore/ops/function/sparse_unary_func.py +14 -12
  310. mindspore/ops/function/vmap_func.py +6 -5
  311. mindspore/ops/functional.py +15 -10
  312. mindspore/ops/op_info_register.py +235 -19
  313. mindspore/ops/operations/__init__.py +25 -17
  314. mindspore/ops/operations/_grad_ops.py +52 -7
  315. mindspore/ops/operations/_inner_ops.py +213 -12
  316. mindspore/ops/operations/_quant_ops.py +4 -8
  317. mindspore/ops/operations/_sequence_ops.py +42 -0
  318. mindspore/ops/operations/array_ops.py +64 -280
  319. mindspore/ops/operations/comm_ops.py +105 -57
  320. mindspore/ops/operations/custom_ops.py +10 -3
  321. mindspore/ops/operations/debug_ops.py +8 -4
  322. mindspore/ops/operations/image_ops.py +18 -12
  323. mindspore/ops/operations/math_ops.py +185 -138
  324. mindspore/ops/operations/nn_ops.py +716 -492
  325. mindspore/ops/operations/other_ops.py +0 -22
  326. mindspore/ops/operations/random_ops.py +53 -111
  327. mindspore/ops/operations/sparse_ops.py +3 -1
  328. mindspore/ops/primitive.py +24 -18
  329. mindspore/parallel/_auto_parallel_context.py +68 -8
  330. mindspore/parallel/_cost_model_context.py +2 -2
  331. mindspore/parallel/_offload_context.py +17 -3
  332. mindspore/parallel/_parallel_serialization.py +2 -2
  333. mindspore/parallel/_ps_context.py +12 -0
  334. mindspore/parallel/_tensor.py +14 -12
  335. mindspore/parallel/_transformer/layers.py +5 -3
  336. mindspore/parallel/_transformer/loss.py +1 -0
  337. mindspore/parallel/_transformer/moe.py +2 -2
  338. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  339. mindspore/parallel/_transformer/transformer.py +23 -3
  340. mindspore/parallel/_utils.py +11 -7
  341. mindspore/parallel/algo_parameter_config.py +85 -5
  342. mindspore/parallel/checkpoint_transform.py +6 -10
  343. mindspore/parallel/shard.py +4 -4
  344. mindspore/profiler/common/struct_type.py +3 -3
  345. mindspore/profiler/common/util.py +3 -2
  346. mindspore/profiler/envprofiling.py +1 -1
  347. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  348. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  349. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  350. mindspore/profiler/parser/ascend_hccl_generator.py +17 -12
  351. mindspore/profiler/parser/ascend_msprof_exporter.py +104 -252
  352. mindspore/profiler/parser/ascend_msprof_generator.py +8 -8
  353. mindspore/profiler/parser/ascend_op_generator.py +5 -5
  354. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  355. mindspore/profiler/parser/ascend_timeline_generator.py +9 -6
  356. mindspore/profiler/parser/base_timeline_generator.py +9 -7
  357. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +14 -10
  358. mindspore/profiler/parser/flops_parser.py +15 -11
  359. mindspore/profiler/parser/framework_parser.py +37 -21
  360. mindspore/profiler/parser/hccl_parser.py +16 -12
  361. mindspore/profiler/parser/integrator.py +22 -11
  362. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  363. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  364. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  365. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  366. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  367. mindspore/profiler/parser/optime_parser.py +1 -1
  368. mindspore/profiler/parser/profiler_info.py +2 -2
  369. mindspore/profiler/parser/step_trace_parser.py +11 -14
  370. mindspore/profiler/profiling.py +139 -71
  371. mindspore/rewrite/api/node.py +102 -19
  372. mindspore/rewrite/api/node_type.py +5 -1
  373. mindspore/rewrite/api/scoped_value.py +9 -17
  374. mindspore/rewrite/api/symbol_tree.py +131 -47
  375. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  376. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  377. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  378. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  379. mindspore/rewrite/common/rewrite_elog.py +5 -1
  380. mindspore/rewrite/namer.py +33 -24
  381. mindspore/rewrite/namespace.py +14 -5
  382. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  383. mindspore/rewrite/node/call_function.py +79 -0
  384. mindspore/rewrite/node/cell_container.py +135 -0
  385. mindspore/rewrite/node/control_flow.py +88 -0
  386. mindspore/rewrite/{node.py → node/node.py} +273 -234
  387. mindspore/rewrite/node/node_manager.py +254 -0
  388. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  389. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  390. mindspore/rewrite/parsers/assign_parser.py +216 -221
  391. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  392. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  393. mindspore/rewrite/parsers/constant_parser.py +9 -6
  394. mindspore/rewrite/parsers/container_parser.py +9 -7
  395. mindspore/rewrite/parsers/for_parser.py +36 -15
  396. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  397. mindspore/rewrite/parsers/if_parser.py +28 -24
  398. mindspore/rewrite/parsers/module_parser.py +196 -25
  399. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  400. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  401. mindspore/rewrite/parsers/return_parser.py +6 -6
  402. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  403. mindspore/rewrite/sparsify/utils.py +1 -1
  404. mindspore/rewrite/symbol_tree.py +525 -577
  405. mindspore/rewrite/symbol_tree_builder.py +9 -193
  406. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  407. mindspore/run_check/_check_version.py +2 -2
  408. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  409. mindspore/safeguard/rewrite_obfuscation.py +517 -0
  410. mindspore/scipy/linalg.py +1 -1
  411. mindspore/scipy/optimize/minimize.py +7 -3
  412. mindspore/train/_utils.py +7 -3
  413. mindspore/train/amp.py +323 -123
  414. mindspore/train/anf_ir_pb2.py +14 -2
  415. mindspore/train/callback/_backup_and_restore.py +2 -12
  416. mindspore/train/callback/_callback.py +29 -4
  417. mindspore/train/callback/_checkpoint.py +23 -8
  418. mindspore/train/callback/_early_stop.py +2 -2
  419. mindspore/train/callback/_landscape.py +4 -4
  420. mindspore/train/callback/_loss_monitor.py +2 -2
  421. mindspore/train/callback/_on_request_exit.py +2 -2
  422. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  423. mindspore/train/callback/_summary_collector.py +14 -7
  424. mindspore/train/callback/_time_monitor.py +58 -5
  425. mindspore/train/data_sink.py +5 -11
  426. mindspore/train/dataset_helper.py +83 -57
  427. mindspore/train/loss_scale_manager.py +2 -2
  428. mindspore/train/metrics/__init__.py +3 -3
  429. mindspore/train/metrics/cosine_similarity.py +1 -1
  430. mindspore/train/metrics/hausdorff_distance.py +3 -2
  431. mindspore/train/metrics/mean_surface_distance.py +3 -2
  432. mindspore/train/metrics/metric.py +39 -19
  433. mindspore/train/metrics/roc.py +2 -2
  434. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  435. mindspore/train/mind_ir_pb2.py +85 -36
  436. mindspore/train/model.py +185 -45
  437. mindspore/train/serialization.py +390 -150
  438. mindspore/train/summary/_writer_pool.py +3 -2
  439. mindspore/train/summary/summary_record.py +14 -10
  440. mindspore/train/train_thor/convert_utils.py +3 -3
  441. mindspore/train/train_thor/dataset_helper.py +1 -1
  442. mindspore/version.py +1 -1
  443. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/METADATA +6 -7
  444. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/RECORD +447 -507
  445. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
  446. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  447. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  448. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  449. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  450. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  451. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  452. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  453. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  454. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  455. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  456. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  457. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  458. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  459. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  460. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  461. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  462. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  463. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  464. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  465. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  466. mindspore/_extends/graph_kernel/expander.py +0 -80
  467. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  468. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  469. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  470. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  471. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  472. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  473. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  474. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  475. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  476. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  477. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  478. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  479. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  480. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  481. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  482. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  483. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  484. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  485. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  486. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  487. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  488. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  489. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  490. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  491. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  492. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  493. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  494. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  495. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  496. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  498. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  499. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  500. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  501. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  502. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  503. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  504. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  505. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  506. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  507. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  508. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  509. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  510. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  511. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  512. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  513. mindspore/dataset/datapreprocess/__init__.py +0 -20
  514. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  515. mindspore/include/api/net.h +0 -142
  516. mindspore/nn/lr_scheduler.py +0 -262
  517. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  518. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  519. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  520. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  521. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  522. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  523. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  524. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  525. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  526. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  527. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  528. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  529. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  530. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  531. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  532. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  533. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  534. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  535. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  536. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  537. mindspore/rewrite/node_visitor.py +0 -44
  538. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
  539. {mindspore-2.1.0.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
mindspore/common/initializer.py CHANGED
@@ -37,7 +37,7 @@ class Initializer:
     Initializers are intended to be used for delayed initialization in parallel mode rather than Tensor
     initialization. If you have to use Initializers to create a Tensor, :func:`mindspore.Tensor.init_data` should be
     followed in most of the cases. For more information, please refer to `mindspore.Tensor.init_data
-    <https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore/Tensor/mindspore.Tensor.init_data.html#
+    <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/Tensor/mindspore.Tensor.init_data.html#
     mindspore-tensor-init-data>`_ .
 
     Args:
@@ -350,8 +350,8 @@ class HeUniform(Initializer):
     .. math::
         boundary = \text{gain} \times \sqrt{\frac{3}{fan\_mode}}
 
-    where :math:`gain` is an optional scaling factor. If :math:`fan\_mode` is 'fan_in', it is the number of input units
-    of the weight tensor. If :math:`fan\_mode` is 'fan_out',
+    where :math:`gain` is an optional scaling factor. If :math:`fan\_mode` is ``'fan_in'``,
+    it is the number of input units of the weight tensor. If :math:`fan\_mode` is ``'fan_out'``,
     it is the number of output units of the weight tensor.
 
     For details of HeUniform algorithm, please check
@@ -487,7 +487,7 @@ class Identity(Initializer):
 class Sparse(Initializer):
     """
     Generates a 2 dimension sparse matrix array in order to initialize a tensor. The non-zero positions
-    will be filled with the value sampled from the normal distribution :math:`{N}(0, 0.01)`.
+    will be filled with the value sampled from the normal distribution :math:`{N}(0, sigma)`.
 
     Args:
         sparsity (float): The fraction of elements being set to zero in each column.
@@ -525,11 +525,11 @@ class Sparse(Initializer):
 class Dirac(Initializer):
     """
     Generates an array with the Dirac delta function in order to initialize a tensor.
-    It tries to preserves the identity of input for convolution layers.
-    For group convolution, each group of channels will be preserved respectively.
+    It's usually used in convolution layers, preserves as many identities of the inputs as possible.
 
     Args:
-        groups (int): The number of group in convolution layer. Default: ``1`` .
+        groups (int): The number of groups in convolution layer. Each group applies the same initialization.
+            Default: ``1`` .
 
     Raises:
         ValueError: If the dimension of the initialized tensor is not in [3, 4, 5].
@@ -582,7 +582,7 @@ class Orthogonal(Initializer):
     If the dimension is greater than 2, the trailing dimensions will be flattened.
 
     Args:
-        gain (float): An optional scaling factor. Default: ``1.`` .
+        gain (float): An optional scaling factor. Default: ``1.0`` .
 
     Raises:
         ValueError: If the dimension of input tensor is less than 2.
@@ -628,11 +628,11 @@ class VarianceScaling(Initializer):
     Generates an random array with scaling in order to initialize a tensor.
     When `distribution` is 'truncated_normal' or 'untruncated_normal', the value will be sampled from truncated or
     untruncated normal distribution with a mean of 0 and a scaled standard deviation
-    :math:`stddev = \sqrt{\frac{scale}{n}}`. :math:`n` will be the number of input units if `mode` is 'fan_in',
+    :math:`stddev = \sqrt{\frac{scale}{n}}`. :math:`n` will be the number of input units if `mode` is ``'fan_in'``,
     while :math:`n` will be
-    the number of output units if `mode` is 'fan_out'. :math:`n` will be the average of 'fan_in' and 'fan_out'
-    if `mode` is 'fan_avg'.
-    When `distribution` is 'uniform', the value will be sampled from a uniform distribution within the limit of
+    the number of output units if `mode` is ``'fan_out'``. :math:`n` will be the average of ``'fan_in'``
+    and ``'fan_out'`` if `mode` is ``'fan_avg'``.
+    When `distribution` is ``'uniform'``, the value will be sampled from a uniform distribution within the limit of
     :math:`[-\sqrt{\frac{3*scale}{n}}, \sqrt{\frac{3*scale}{n}}]`.
 
     Args:
@@ -643,8 +643,8 @@ class VarianceScaling(Initializer):
 
     Raises:
         ValueError: If `scale` is not greater than 0.
-        ValueError: If `mode` is not 'fan_in', 'fan_out' or 'fan_avg'.
-        ValueError: If `distribution` is not 'uniform', 'truncated_normal' or 'untruncated_normal'.
+        ValueError: If `mode` is not ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'``.
+        ValueError: If `distribution` is not ``'uniform'``, ``'truncated_normal'`` or ``'untruncated_normal'``.
 
     Examples:
         >>> import mindspore
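
For context on the docstring edits above, a minimal sketch of how these initializers are driven (the shapes and dtype are arbitrary illustration values, not taken from the diff):

    import mindspore
    from mindspore.common.initializer import initializer, HeUniform, VarianceScaling

    # HeUniform draws from U(-boundary, boundary) with
    # boundary = gain * sqrt(3 / fan_mode); mode='fan_in' uses the input unit count.
    w1 = initializer(HeUniform(mode='fan_in'), [64, 3, 3, 3], mindspore.float32)

    # VarianceScaling with mode='fan_avg' averages fan_in and fan_out for n;
    # distribution='uniform' samples within [-sqrt(3*scale/n), sqrt(3*scale/n)].
    w2 = initializer(VarianceScaling(scale=1.0, mode='fan_avg', distribution='uniform'),
                     [64, 128], mindspore.float32)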
mindspore/common/jit_config.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
+# Copyright 2022-2023 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -23,19 +23,34 @@ class JitConfig:
     This is an experimental API that is subject to change or deletion.
 
     Args:
-        jit_level (str): Option for argument `level` for Optimization of lift graph.
-            Supports ["O0", "O1", "O2", "O3"]. Default: ``"O1"`` .
+        jit_level (str, optional): Used to control the compilation optimization level.
+            Supports ["O0", "O1", "O2"]. Default: ``"O1"`` .
 
-            - "O0": Basic optimization.
-            - "O1": Manual optimization.
-            - "O2": Manual optimization and graph computation fusion.
-            - "O3": Performance optimization, no generalization guaranteed.
+            - ``"O0"``: Except for optimizations that may affect functionality, all other optimizations are turned off.
+            - ``"O1"``: Using commonly used optimizations, recommended to set the O1 level.
+            - ``"O2"``: Activate some experimental level optimizations.
 
-        exc_mode (str): Mode for execute the network. Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
+        exc_mode (str, optional): Control the execution mode of the model.
+            Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
 
-            - "auto": Automatic Policies.
-            - "sink": Build computational graphs with the sink mode.
-            - "no_sink": Build computational graphs with no sink mode.
+            - ``"auto"``: The framework automatically selects the execution method.
+            - ``"sink"``: Support the network to load and load the entire device at once, and then execute it by
+              input driver, without the need to iterate through each operator to achieve better execution performance.
+              This mode is only supported on the Ascend backend.
+            - ``"no_sink"``: The network model is executed asynchronously one by one using a single operator.
+
+        jit_syntax_level (str, optional): JIT syntax level for graph compiling.
+            The value must be ``"STRICT"`` , ``"LAX"`` or ``""`` . Default to an empty string, which means that this
+            JitConfig configuration will be ignored and the jit_syntax_level of ms.context will be used.
+            For more details about ms.context, refer to
+            `set_context <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore/mindspore.set_context.html>`_ .
+            Default: ``""`` .
+
+            - ``"STRICT"``: Only basic syntax is supported, and execution performance is optimal. Can be used for MindIR
+              load and export.
+            - ``"LAX"``: Compatible with all Python syntax as much as possible. However, execution performance may be
+              affected and not optimal. Cannot be used for MindIR load and export due to some syntax that may not be
+              able to be exported.
 
         **kwargs (dict): A dictionary of keyword arguments that the class needs.
 
@@ -45,16 +60,19 @@ class JitConfig:
         >>> jitconfig = JitConfig(jit_level="O1")
         >>>
         >>> # Define the network structure of LeNet5. Refer to
-        >>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
+        >>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
         >>> net = LeNet5()
         >>>
         >>> net.set_jit_config(jitconfig)
     """
-    def __init__(self, jit_level="O1", exc_mode="auto", **kwargs):
-        if jit_level not in ["O0", "O1", "O2", "O3"]:
-            raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2', 'O3'].")
+    def __init__(self, jit_level="O1", exc_mode="auto", jit_syntax_level="", **kwargs):
+        if jit_level not in ["O0", "O1", "O2"]:
+            raise ValueError("For 'jit_level' must be one of ['O0', 'O1', 'O2'].")
         if exc_mode not in ['auto', 'sink', 'no_sink']:
            raise ValueError("For 'exc_mode' must be one of '['auto', 'sink', 'no_sink']'.")
+        if jit_syntax_level != "" and jit_syntax_level not in ['STRICT', 'COMPATIBLE', 'LAX']:
+            raise ValueError("For 'jit_syntax_level' must be one of '['STRICT', 'LAX']'.")
         self.jit_config_dict = kwargs
         self.jit_config_dict["jit_level"] = jit_level
         self.jit_config_dict["exc_mode"] = exc_mode
+        self.jit_config_dict["jit_syntax_level"] = jit_syntax_level
mindspore/common/lazy_inline.py CHANGED
@@ -21,17 +21,136 @@ from functools import wraps
 
 def lazy_inline(fn=None, attrs=None):
     """
-    Make the cell to be reusable. The function graph will not be inline in the front.
+    Make the cell to be reusable. The corresponding sub graph will not be inline at first.
 
-    Registering the decorator of the built-in operator cell __init__
-    function will add save all the parameters of __init__ as operator attributes.
+    Registering the decorator of the built-in function `__init__` of a cell, the decorator
+    will add the parameters of `__init__` according to the `attrs` as the attributes of this cell.
+
+    .. warning::
+        This feature is only supported on Ascend and is not supported on other hardwares.
 
     Args:
-        fn (function): __init__ function of cell.
-        attrs (list(string) | string): attr list.
+        fn (function): `__init__` function of a cell.
+        attrs (Union[list[string], string]): The attributes list to add for the cell.
 
     Returns:
         function, original function.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> import mindspore.nn as nn
+        >>> from mindspore import lazy_inline
+        >>> from mindspore import context
+        >>> from mindspore import ops
+        >>> def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
+        ...     return nn.Conv2d(in_channels, out_channels,
+        ...                      kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode)
+        ...
+        >>> def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
+        ...     return nn.Conv2d(in_channels, out_channels,
+        ...                      kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode)
+        ...
+        >>> class Block(nn.Cell):
+        ...     expansion = 4
+        ...
+        ...     @lazy_inline
+        ...     def __init__(self,
+        ...                  in_channels,
+        ...                  out_channels,
+        ...                  stride=1,
+        ...                  down_sample=False):
+        ...         super(Block, self).__init__()
+        ...
+        ...         out_chls = out_channels
+        ...         self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
+        ...         self.bn1 = nn.BatchNorm2d(out_chls)
+        ...
+        ...         self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
+        ...         self.bn2 = nn.BatchNorm2d(out_chls)
+        ...
+        ...         self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
+        ...         self.bn3 = nn.BatchNorm2d(out_channels)
+        ...
+        ...         self.relu = nn.ReLU()
+        ...         self.downsample = down_sample
+        ...
+        ...         self.conv_down_sample = conv1x1(in_channels, out_channels,
+        ...                                         stride=stride, padding=0)
+        ...         self.bn_down_sample = nn.BatchNorm2d(out_channels)
+        ...         self.add = ops.Add()
+        ...
+        ...     def construct(self, x):
+        ...         identity = x
+        ...
+        ...         out = self.conv1(x)
+        ...         out = self.bn1(out)
+        ...         out = self.relu(out)
+        ...
+        ...         out = self.conv2(out)
+        ...         out = self.bn2(out)
+        ...         out = self.relu(out)
+        ...
+        ...         out = self.conv3(out)
+        ...         out = self.bn3(out)
+        ...
+        ...         if self.downsample:
+        ...             identity = self.conv_down_sample(identity)
+        ...             identity = self.bn_down_sample(identity)
+        ...
+        ...         out = self.add(out, identity)
+        ...         out = self.relu(out)
+        ...
+        ...         return out
+        ...
+        >>> class Net(nn.Cell):
+        ...     def __init__(self, block, num_classes=100):
+        ...         super(Net, self).__init__()
+        ...
+        ...         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
+        ...         self.bn1 = nn.BatchNorm2d(64)
+        ...         self.relu = nn.ReLU()
+        ...         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')
+        ...
+        ...         self.layer = self.MakeLayer(
+        ...             block, 50, in_channels=64, out_channels=2048, stride=2)
+        ...         self.avgpool = nn.AvgPool2d(7, 1)
+        ...         self.flatten = ops.Flatten()
+        ...
+        ...     def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
+        ...         layers = []
+        ...         resblk = block(in_channels, out_channels,
+        ...                        stride=stride, down_sample=True)
+        ...         layers.append(resblk)
+        ...
+        ...         for _ in range(1, layer_num):
+        ...             resblk = block(out_channels, out_channels, stride=1)
+        ...             layers.append(resblk)
+        ...
+        ...         return nn.SequentialCell(layers)
+        ...
+        ...     def construct(self, x):
+        ...         x = self.conv1(x)
+        ...         x = self.bn1(x)
+        ...         x = self.relu(x)
+        ...         x = self.maxpool(x)
+        ...         x = self.layer(x)
+        ...         x = self.avgpool(x)
+        ...         x = self.flatten(x)
+        ...         return x
+        ...
+        >>> def test_compile():
+        ...     net = Net(Block)
+        ...     inp = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32))
+        ...     net(inp)
+        ...
+        >>> context.set_context(mode=context.GRAPH_MODE,
+        ...                     save_graphs=True, save_graphs_path="./lazy")
+        ...
+        >>> test_compile()
     """
 
     def wrap_cell(fn):
@@ -45,7 +164,7 @@ def lazy_inline(fn=None, attrs=None):
             arguments = arguments.values()
             fn(self, *args, **kwargs)
             if attrs is None:
-                self.cell_init_args = type(self).__name__ + str(arguments)
+                self.cell_init_args = "lazy_inline_" + type(self).__name__ + str(arguments)
                 return
 
             if isinstance(attrs, list):
@@ -59,7 +178,7 @@ def lazy_inline(fn=None, attrs=None):
                 arguments = getattr(self, attrs)
             else:
                 raise ValueError(f"attrs must be list or string")
-            self.cell_init_args = type(self).__name__ + str(arguments)
+            self.cell_init_args = "lazy_inline_" + type(self).__name__ + str(arguments)
 
         return deco
 
mindspore/common/mindir_util.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""mindir utility."""
+from __future__ import absolute_import
+
+import os
+from mindspore import log as logger
+from mindspore import _checkparam as Validator
+from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
+
+
+def load_mindir(file_name):
+    """
+    load protobuf file.
+
+    Args:
+        file_name (str): File name.
+
+    Returns:
+        ModelProto, mindir proto object.
+
+    Raises:
+        ValueError: The file does not exist or the file name format is incorrect.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> md = ms.load_mindir("test.mindir")
+    """
+
+    Validator.check_file_name_by_regular(file_name)
+    file_name = os.path.realpath(file_name)
+    model = mindir_model()
+
+    try:
+        with open(file_name, "rb") as f:
+            pb_content = f.read()
+            model.ParseFromString(pb_content)
+    except BaseException as e:
+        logger.critical(f"Failed to parse the file: {file_name} "
+                        f" please check the correct file.")
+        raise ValueError(e.__str__()) from e
+    finally:
+        pass
+
+    return model
+
+
+def save_mindir(model, file_name):
+    """
+    save protobuf file.
+
+    Args:
+        model (ModelProto): mindir model
+        file_name (str): File name.
+
+    Raises:
+        TypeError: The argument `model` is not a ModelProto object.
+        ValueError: The file path does not exist or the `file_name` format is incorrect.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> md = ms.load_mindir("test.mindir")
+        >>> md.user_info["version"]="pangu v100"
+        >>> ms.save_mindir(md,"test_new.mindir")
+        >>> md_new = ms.load_mindir("test_new.mindir")
+        >>> md_new.user_info
+    """
+
+    Validator.check_file_name_by_regular(file_name)
+    file_name = os.path.realpath(file_name)
+
+    if not isinstance(model, mindir_model):
+        raise TypeError("For 'save_mindir', the argument 'model' must be ModelProto, "
+                        "but got {}.".format(type(model)))
+    try:
+        with open(file_name, "wb") as f:
+            f.write(model.SerializeToString())
+    except BaseException as e:
+        logger.critical(f"Failed to save the file: {file_name} ,"
+                        f" please check the correct file.")
+        raise ValueError(e.__str__()) from e
+    finally:
+        pass
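
Condensing the docstring examples above: the two new helpers support a load, edit, save round trip over the MindIR protobuf. A short sketch (file names are placeholders; user_info is the mutable mapping shown in the save_mindir docstring):

    import mindspore as ms

    # Load an existing MindIR file into a ModelProto object.
    md = ms.load_mindir("test.mindir")

    # Tag the model and write it back; save_mindir raises TypeError
    # if the first argument is not a ModelProto.
    md.user_info["version"] = "pangu v100"
    ms.save_mindir(md, "test_new.mindir")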
mindspore/common/parameter.py CHANGED
@@ -125,15 +125,22 @@ class Parameter(Tensor_):
     the list of its parameters, and will appear, e.g. in `cell.get_parameters()` iterator.
 
     Note:
-        In auto_parallel mode of "semi_auto_parallel" and "auto_parallel", if init `Parameter` by
-        a `Tensor`, the type of Parameter will be `Tensor`. `Tensor`
-        will save the shape and type info of a tensor with no memory usage. The shape can be changed while
-        compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
-        If there is an operator in the network that requires part of the inputs to be Parameter,
-        then the Parameters as this part of the inputs are not allowed to be cast.
-        Give each `Parameter` a unique name to facilitate subsequent operations and updates.
-        If there are two or more `Parameter` objects with the same name in a network,
-        will be prompted to set a unique name when defining.
+        - In auto_parallel mode of `SEMI_AUTO_PARALLEL` and `AUTO_PARALLEL`, if init `Parameter` by
+          a `Tensor`, the type of Parameter will be `Tensor`. `Tensor` will save the shape and type info of a tensor
+          with no memory usage.
+
+        - The shape can be changed while
+          compiling for auto-parallel. Call `init_data` will return a Tensor Parameter with initialized data.
+
+        - If there is an operator in the network that requires part of the inputs to be Parameter,
+          then the Parameters as this part of the inputs are not allowed to be cast.
+
+        - Give each `Parameter` a unique name to facilitate subsequent operations and updates.
+          If there are two or more `Parameter` objects with the same name in a network,
+          will be prompted to set a unique name when defining.
+
+        - When directly printing a `Parameter`, you cannot view the actual values contained inside it.
+          You need to use the `Parameter.asnumpy()` method to access the actual values.
 
     Args:
         default_input (Union[Tensor, int, float, numpy.ndarray, list]): Parameter data,
@@ -174,11 +181,11 @@ class Parameter(Tensor_):
             self.param_tuple = (self.param_a, self.param_a)
 
         requires_grad (bool): True if the parameter requires gradient. Default: ``True`` .
-        layerwise_parallel (bool): When layerwise_parallel is true in data/hybrid parallel mode,
-            broadcast and gradients communication would not be applied to parameters. Default: ``False`` .
-        parallel_optimizer (bool): It is used to filter the weight shard operation in semi auto or auto parallel
-            mode. It works only when enable parallel optimizer in `mindspore.set_auto_parallel_context()`.
-            Default: ``True`` .
+        layerwise_parallel (bool): When `layerwise_parallel` is true in data/hybrid parallel mode,
+            broadcast and gradients communication would not be applied to the `Parameter`. Default: ``False`` .
+        parallel_optimizer (bool): It is used to filter the weight shard operation in `SEMI_AUTO_PARALLEL` or
+            `AUTO_PARALLEL` mode. It works only when enable parallel optimizer in
+            `mindspore.set_auto_parallel_context()`. Default: ``True`` .
 
     Examples:
         >>> import numpy as np
@@ -322,6 +329,8 @@ class Parameter(Tensor_):
             # in other place, so we can make a Tensor without copy data.
             return (Tensor, data)
         # make a copy of Tensor to init the parameter.
+        if data.dtype == mstype.bfloat16:
+            return (Tensor, data.float().asnumpy(), mstype.bfloat16)
         return (Tensor, data.asnumpy())
 
     not_init_data = _is_role_sched() or (_is_role_pserver() and _cache_enable()) or _is_in_parallel_mode()
@@ -348,11 +357,9 @@ class Parameter(Tensor_):
             init_in_server (bool): Whether trainable parameter updated by parameter server is
                 initialized on server. Default: ``False``.
 
-        Examples:
-            >>> from mindspore import Tensor, Parameter
-            >>> import numpy as np
-            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
-            >>> x.set_param_ps(True)
+        Tutorial Examples:
+            - `Parameter Server Mode
+              <https://www.mindspore.cn/tutorials/experts/en/r2.2/parallel/parameter_server_training.html>`_
         """
         if not _is_ps_mode() or not (_is_role_worker() or _is_role_pserver() or _is_role_sched()):
             raise RuntimeError("Must complete following two steps before calling set_param_ps: \n"
@@ -404,7 +411,7 @@ class Parameter(Tensor_):
            >>> from mindspore import Tensor, Parameter
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([1, 2], dtype=np.float32)), name="param")
-           >>> x.inited_param()
+           >>> x.inited_param
        """
        return self._inited_param
 
@@ -484,8 +491,9 @@ class Parameter(Tensor_):
        Get the fusion type (int) for communication operators corresponding to this parameter.
 
        In `AUTO_PARALLEL` and `SEMI_AUTO_PARALLEL` mode, some communication operators used for parameters or
-       gradients aggregation are inserted automatically. The value of fusion must be greater than or equal to 0.
-       When the value of fusion is 0, operators will not be fused together.
+       gradients aggregation are inserted automatically.
+       The value of `comm_fusion` must be greater than or equal to 0.
+       When the value of `comm_fusion` is ``0`` , operators will not be fused together.
 
        Examples:
            >>> from mindspore import Tensor, Parameter
@@ -563,8 +571,8 @@ class Parameter(Tensor_):
            If `init` is a `Tensor` or `numbers.Number`, clone a new parameter with the same shape
            and dtype, and the data of the new parameter will be set according to `init`. If `init`
            is a `str`, the `init` should be the alias of the class inheriting from `Initializer`.
-           For example, if `init` is 'same', clone a new parameter with the same data, shape, and
-           dtype. Default: 'same'.
+           For example, if `init` is ``'same'``, clone a new parameter with the same data, shape, and
+           dtype. Default: ``'same'``.
 
        Returns:
            Parameter, a new parameter.
@@ -606,8 +614,8 @@ class Parameter(Tensor_):
        """
        Get the layerwise parallel status(bool) of the parameter.
 
-       When layerwise_parallel is true in `DATA_PARALLEL` and `HYBRID_PARALLEL` parallel mode, broadcast and gradients
-       communication would not be applied to parameters.
+       When `layerwise_parallel` is ``True`` in `DATA_PARALLEL` and `HYBRID_PARALLEL` parallel mode,
+       broadcast and gradients communication would not be applied to parameters.
 
        Examples:
            >>> from mindspore import Tensor, Parameter
@@ -745,7 +753,7 @@ class Parameter(Tensor_):
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
            >>> x.data
-           Parameter (name=Parameter, shape=(2, 2), dtype=float32, requires=True)
+           Parameter (name=param, shape=(2, 2), dtype=Float32, requires_grad=True)
        """
        return self
 
@@ -806,6 +814,7 @@ class Parameter(Tensor_):
        Tensor_.__init__(param, tensor)
        param.init = None
        param.init_mode = None
+       param.has_init = False
        param.is_default_input_init = False
        Parameter.__init__(param, tensor, *args, **kwargs)
        return param
@@ -817,8 +826,9 @@ class Parameter(Tensor_):
 
        Args:
            data (Union[Tensor, int, float]): New data.
-           slice_shape (bool): If slice the parameter is set to true, the shape is not checked for consistency.
-               Default: ``False``.
+           slice_shape (bool): If slice the parameter is set to ``True``, the shape consistency will not be checked.
+               Default: ``False``. When `slice_shape` is ``True``, and the shapes are not consistent, a
+               ValueError will be thrown.
 
        Returns:
            Parameter, the parameter after set data.
@@ -828,7 +838,7 @@ class Parameter(Tensor_):
            >>> import numpy as np
            >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
            >>> x.set_data(Tensor(np.array([[6, 6], [6, 6]], dtype=np.float32)))
-           Parameter (name=Parameter, shape=(2, 2), dtype=float32, requires=True)
+           Parameter (name=param, shape=(2, 2), dtype=Float32, requires_grad=True)
        """
        if not isinstance(data, (Tensor, int, float)):
            raise TypeError(f"Parameter data must be [`Tensor`, `int`, `float`] or a kind of `Tensor` "
@@ -843,7 +853,7 @@ class Parameter(Tensor_):
        Parameter._set_data_check_input_valid(self.shape, data.shape, current_tensor_is_init, incoming_tensor_is_init,
                                              slice_shape, self.slice_num)
        if self.dtype != data.dtype:
-           if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
+           if mstype.implicit_conversion_seq.get(self.dtype) < mstype.implicit_conversion_seq.get(data.dtype):
                self._raise_type_error(data.dtype)
            else:
                from mindspore.ops import functional as F
@@ -950,12 +960,12 @@ class ParameterTuple(tuple):
    It is used to store the parameters of the network into the parameter tuple collection.
 
    Examples:
-        >>> from mindspore import Tensor, Parameter, ParameterTuple
-        >>> import numpy as np
-        >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
-        >>> y = Parameter(Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)), name="param1")
-        >>> pt = ParameterTuple([x, y])
-        >>> pt1 = pt.clone(prefix="new")
+        >>> from mindspore import Tensor, Parameter, ParameterTuple
+        >>> import numpy as np
+        >>> x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")
+        >>> y = Parameter(Tensor(np.array([[5, 6], [7, 8]], dtype=np.float32)), name="param1")
+        >>> pt = ParameterTuple([x, y])
+        >>> pt1 = pt.clone(prefix="new")
    """
 
    def __new__(cls, iterable):
@@ -985,20 +995,20 @@ class ParameterTuple(tuple):
                in parametertuple.
 
            init (Union[Tensor, str, numbers.Number]): Clone the shape and dtype of Parameters in ParameterTuple and
-               set data according to `init`. Default: 'same'.
+               set data according to `init`. Default: ``'same'``.
 
                - If `init` is a `Tensor` , set the new Parameter data to the input Tensor.
                - If `init` is `numbers.Number` , set the new Parameter data to the input number.
                - If `init` is a `str`, data will be set according to the initialization method of the same name in
-                 the `Initializer`.
-               - If `init` is 'same', the new Parameter has the same value with the original Parameter.
+                 the `Initializer`. When it is ``'same'``, the new Parameter will have the same value
+                 with the original Parameter.
 
        Returns:
            Tuple, the new Parameter tuple.
 
        Tutorial Examples:
            - `Cell and Parameter - Parameter Tuple
-             <https://mindspore.cn/tutorials/en/r2.1/advanced/modules/layer.html#parameter-tuple>`_
+             <https://mindspore.cn/tutorials/en/r2.2/advanced/modules/layer.html#parameter-tuple>`_
        """
        Validator.check_str_by_regular(prefix)
        new = []
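
The repr fixes and the new Note above are easier to see in a short sketch (values illustrative, repr strings per the corrected docstrings):

    import numpy as np
    from mindspore import Tensor, Parameter

    x = Parameter(Tensor(np.array([[1, 2], [3, 4]], dtype=np.float32)), name="param")

    # Printing a Parameter shows its metadata, not the values:
    # Parameter (name=param, shape=(2, 2), dtype=Float32, requires_grad=True)
    print(x)

    # As the new Note says, use asnumpy() to inspect the actual values.
    print(x.asnumpy())

    # set_data replaces the data in place and returns the parameter.
    x.set_data(Tensor(np.array([[6, 6], [6, 6]], dtype=np.float32)))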
mindspore/common/seed.py CHANGED
@@ -41,12 +41,11 @@ def set_seed(seed):
     Set global seed.
 
     Note:
-        The global seed is used by numpy.random, mindspore.common.Initializer, mindspore.ops.function.random_func and
+        The global seed is used by numpy.random, mindspore.common.Initializer and
         mindspore.nn.probability.distribution.
 
         If global seed is not set, these packages will use their own default seed independently, numpy.random and
-        mindspore.common.Initializer will choose a random seed, mindspore.ops.function.random_func and
-        mindspore.nn.probability.distribution will use zero.
+        mindspore.common.Initializer will choose a random seed, mindspore.nn.probability.distribution will use zero.
 
         Seed set by numpy.random.seed() only used by numpy.random, while seed set by this API will also used by
         numpy.random, so just set all seed by this API is recommended.
@@ -198,7 +197,8 @@ def _update_seeds(op_seed, kernel_name):
     """
     global _KERNEL_SEED
     if op_seed is not None:
-        _KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED[(kernel_name, op_seed)] + (keyConstant[0] ^ keyConstant[2])
+        _KERNEL_SEED[(kernel_name, op_seed)] = _KERNEL_SEED.get((kernel_name, op_seed)) + \
+            (keyConstant[0] ^ keyConstant[2])
 
 
 def _get_op_seed(op_seed, kernel_name):
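
A brief sketch of the documented set_seed behavior (the seed value is arbitrary):

    import numpy as np
    import mindspore as ms

    # Per the updated Note, set_seed also seeds numpy.random and
    # mindspore.common.Initializer, so separate numpy.random.seed() calls
    # are not needed.
    ms.set_seed(57)
    a = np.random.rand(2)

    ms.set_seed(57)
    b = np.random.rand(2)
    assert np.allclose(a, b)  # same global seed -> same numpy draw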