mindspore 2.1.0__cp38-none-any.whl → 2.2.10__cp38-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. See the package registry's release page for more details.

Files changed (569)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +46 -19
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/__init__.py +0 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  25. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  26. mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +38 -0
  31. mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +12 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +61 -71
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +74 -104
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +13 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +28 -5
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8928 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  196. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  197. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  198. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  199. mindspore/nn/__init__.py +0 -2
  200. mindspore/nn/cell.py +313 -74
  201. mindspore/nn/dynamic_lr.py +21 -21
  202. mindspore/nn/layer/activation.py +22 -30
  203. mindspore/nn/layer/basic.py +15 -13
  204. mindspore/nn/layer/channel_shuffle.py +1 -1
  205. mindspore/nn/layer/container.py +271 -9
  206. mindspore/nn/layer/conv.py +323 -204
  207. mindspore/nn/layer/dense.py +8 -5
  208. mindspore/nn/layer/embedding.py +33 -27
  209. mindspore/nn/layer/flash_attention.py +141 -88
  210. mindspore/nn/layer/image.py +8 -6
  211. mindspore/nn/layer/math.py +16 -25
  212. mindspore/nn/layer/normalization.py +107 -66
  213. mindspore/nn/layer/padding.py +1 -1
  214. mindspore/nn/layer/pooling.py +131 -109
  215. mindspore/nn/layer/rnn_cells.py +27 -22
  216. mindspore/nn/layer/rnns.py +13 -16
  217. mindspore/nn/layer/thor_layer.py +1 -1
  218. mindspore/nn/layer/transformer.py +221 -154
  219. mindspore/nn/learning_rate_schedule.py +9 -1
  220. mindspore/nn/loss/loss.py +235 -174
  221. mindspore/nn/optim/ada_grad.py +2 -1
  222. mindspore/nn/optim/adadelta.py +1 -0
  223. mindspore/nn/optim/adafactor.py +2 -1
  224. mindspore/nn/optim/adam.py +7 -4
  225. mindspore/nn/optim/adamax.py +3 -2
  226. mindspore/nn/optim/adasum.py +2 -2
  227. mindspore/nn/optim/asgd.py +2 -3
  228. mindspore/nn/optim/ftrl.py +6 -5
  229. mindspore/nn/optim/lamb.py +7 -4
  230. mindspore/nn/optim/lars.py +1 -1
  231. mindspore/nn/optim/lazyadam.py +5 -3
  232. mindspore/nn/optim/momentum.py +2 -1
  233. mindspore/nn/optim/optimizer.py +53 -4
  234. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  235. mindspore/nn/optim/rmsprop.py +4 -3
  236. mindspore/nn/optim/rprop.py +23 -12
  237. mindspore/nn/optim/sgd.py +26 -11
  238. mindspore/nn/optim/thor.py +9 -7
  239. mindspore/nn/probability/bijector/bijector.py +5 -5
  240. mindspore/nn/probability/bijector/power_transform.py +27 -27
  241. mindspore/nn/probability/bijector/softplus.py +3 -3
  242. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  243. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  244. mindspore/nn/probability/distribution/beta.py +3 -3
  245. mindspore/nn/probability/distribution/categorical.py +7 -7
  246. mindspore/nn/probability/distribution/cauchy.py +0 -1
  247. mindspore/nn/probability/distribution/distribution.py +3 -3
  248. mindspore/nn/probability/distribution/gamma.py +3 -3
  249. mindspore/nn/probability/distribution/geometric.py +4 -4
  250. mindspore/nn/probability/distribution/gumbel.py +4 -4
  251. mindspore/nn/probability/distribution/log_normal.py +2 -2
  252. mindspore/nn/probability/distribution/logistic.py +2 -2
  253. mindspore/nn/probability/distribution/poisson.py +4 -4
  254. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  255. mindspore/nn/probability/distribution/uniform.py +6 -6
  256. mindspore/nn/wrap/cell_wrapper.py +84 -34
  257. mindspore/nn/wrap/grad_reducer.py +8 -5
  258. mindspore/nn/wrap/loss_scale.py +105 -42
  259. mindspore/numpy/array_creations.py +1 -2
  260. mindspore/numpy/array_ops.py +3 -2
  261. mindspore/numpy/utils_const.py +5 -5
  262. mindspore/offline_debug/convert_async.py +2 -2
  263. mindspore/ops/_grad_experimental/__init__.py +0 -5
  264. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  265. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  266. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  267. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  268. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  269. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  270. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  271. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  272. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  273. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
  274. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
  275. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
  276. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
  277. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
  278. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
  279. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  280. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  281. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  282. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  283. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  284. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  285. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  286. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  287. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  288. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  289. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  290. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  291. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  292. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  293. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  294. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  295. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  296. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  297. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  298. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  299. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  300. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  301. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  302. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  303. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  304. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  305. mindspore/ops/_primitive_cache.py +1 -1
  306. mindspore/ops/_tracefunc.py +45 -13
  307. mindspore/ops/_utils/utils.py +6 -1
  308. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  309. mindspore/ops/_vmap/vmap_base.py +3 -3
  310. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  311. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  312. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  313. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  314. mindspore/ops/arg_dtype_cast.py +54 -0
  315. mindspore/ops/composite/base.py +37 -10
  316. mindspore/ops/composite/math_ops.py +5 -4
  317. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  318. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  319. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  320. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  321. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  322. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  323. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  324. mindspore/ops/deprecated.py +304 -0
  325. mindspore/ops/function/__init__.py +4 -1
  326. mindspore/ops/function/array_func.py +174 -193
  327. mindspore/ops/function/clip_func.py +81 -13
  328. mindspore/ops/function/debug_func.py +1 -1
  329. mindspore/ops/function/grad/grad_func.py +18 -9
  330. mindspore/ops/function/image_func.py +10 -4
  331. mindspore/ops/function/linalg_func.py +5 -5
  332. mindspore/ops/function/math_func.py +575 -386
  333. mindspore/ops/function/nn_func.py +568 -260
  334. mindspore/ops/function/random_func.py +88 -57
  335. mindspore/ops/function/sparse_func.py +1 -1
  336. mindspore/ops/function/sparse_unary_func.py +14 -12
  337. mindspore/ops/function/vmap_func.py +6 -5
  338. mindspore/ops/functional.py +15 -10
  339. mindspore/ops/op_info_register.py +244 -25
  340. mindspore/ops/operations/__init__.py +28 -19
  341. mindspore/ops/operations/_grad_ops.py +72 -7
  342. mindspore/ops/operations/_inner_ops.py +350 -17
  343. mindspore/ops/operations/_quant_ops.py +4 -8
  344. mindspore/ops/operations/_sequence_ops.py +42 -0
  345. mindspore/ops/operations/array_ops.py +68 -282
  346. mindspore/ops/operations/comm_ops.py +107 -59
  347. mindspore/ops/operations/custom_ops.py +94 -70
  348. mindspore/ops/operations/debug_ops.py +8 -4
  349. mindspore/ops/operations/image_ops.py +18 -12
  350. mindspore/ops/operations/inner_ops.py +26 -3
  351. mindspore/ops/operations/math_ops.py +189 -141
  352. mindspore/ops/operations/nn_ops.py +794 -489
  353. mindspore/ops/operations/other_ops.py +0 -22
  354. mindspore/ops/operations/random_ops.py +53 -111
  355. mindspore/ops/operations/sparse_ops.py +3 -1
  356. mindspore/ops/primitive.py +24 -18
  357. mindspore/parallel/_auto_parallel_context.py +68 -8
  358. mindspore/parallel/_cost_model_context.py +2 -2
  359. mindspore/parallel/_offload_context.py +17 -3
  360. mindspore/parallel/_parallel_serialization.py +12 -5
  361. mindspore/parallel/_ps_context.py +12 -0
  362. mindspore/parallel/_tensor.py +18 -13
  363. mindspore/parallel/_transformer/layers.py +5 -3
  364. mindspore/parallel/_transformer/loss.py +1 -0
  365. mindspore/parallel/_transformer/moe.py +2 -2
  366. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  367. mindspore/parallel/_transformer/transformer.py +23 -3
  368. mindspore/parallel/_utils.py +11 -7
  369. mindspore/parallel/algo_parameter_config.py +85 -5
  370. mindspore/parallel/checkpoint_transform.py +19 -12
  371. mindspore/parallel/shard.py +21 -14
  372. mindspore/profiler/common/struct_type.py +3 -3
  373. mindspore/profiler/common/util.py +4 -2
  374. mindspore/profiler/envprofiling.py +1 -1
  375. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  376. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  377. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  378. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  379. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  380. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  381. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  382. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  383. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  384. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  385. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  386. mindspore/profiler/parser/flops_parser.py +15 -11
  387. mindspore/profiler/parser/framework_parser.py +38 -22
  388. mindspore/profiler/parser/hccl_parser.py +16 -12
  389. mindspore/profiler/parser/integrator.py +22 -11
  390. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  391. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  392. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  393. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  394. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  395. mindspore/profiler/parser/optime_parser.py +1 -1
  396. mindspore/profiler/parser/profiler_info.py +21 -2
  397. mindspore/profiler/parser/step_trace_parser.py +11 -14
  398. mindspore/profiler/profiling.py +179 -89
  399. mindspore/rewrite/api/node.py +102 -19
  400. mindspore/rewrite/api/node_type.py +5 -1
  401. mindspore/rewrite/api/pattern_engine.py +1 -1
  402. mindspore/rewrite/api/scoped_value.py +9 -17
  403. mindspore/rewrite/api/symbol_tree.py +131 -47
  404. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  405. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  406. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  407. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  408. mindspore/rewrite/common/rewrite_elog.py +5 -1
  409. mindspore/rewrite/namer.py +33 -24
  410. mindspore/rewrite/namespace.py +14 -5
  411. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  412. mindspore/rewrite/node/call_function.py +79 -0
  413. mindspore/rewrite/node/cell_container.py +135 -0
  414. mindspore/rewrite/node/control_flow.py +88 -0
  415. mindspore/rewrite/{node.py → node/node.py} +273 -234
  416. mindspore/rewrite/node/node_manager.py +254 -0
  417. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  418. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  419. mindspore/rewrite/parsers/assign_parser.py +216 -221
  420. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  421. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  422. mindspore/rewrite/parsers/constant_parser.py +9 -6
  423. mindspore/rewrite/parsers/container_parser.py +9 -7
  424. mindspore/rewrite/parsers/for_parser.py +36 -15
  425. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  426. mindspore/rewrite/parsers/if_parser.py +28 -24
  427. mindspore/rewrite/parsers/module_parser.py +196 -25
  428. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  429. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  430. mindspore/rewrite/parsers/return_parser.py +6 -6
  431. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  432. mindspore/rewrite/sparsify/utils.py +1 -1
  433. mindspore/rewrite/symbol_tree.py +523 -578
  434. mindspore/rewrite/symbol_tree_builder.py +9 -193
  435. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  436. mindspore/run_check/_check_version.py +6 -4
  437. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  438. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  439. mindspore/scipy/linalg.py +1 -1
  440. mindspore/scipy/optimize/minimize.py +7 -3
  441. mindspore/train/_utils.py +7 -3
  442. mindspore/train/amp.py +323 -123
  443. mindspore/train/anf_ir_pb2.py +14 -2
  444. mindspore/train/callback/_backup_and_restore.py +2 -12
  445. mindspore/train/callback/_callback.py +29 -4
  446. mindspore/train/callback/_checkpoint.py +23 -8
  447. mindspore/train/callback/_early_stop.py +2 -2
  448. mindspore/train/callback/_landscape.py +4 -4
  449. mindspore/train/callback/_loss_monitor.py +2 -2
  450. mindspore/train/callback/_on_request_exit.py +2 -2
  451. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  452. mindspore/train/callback/_summary_collector.py +15 -8
  453. mindspore/train/callback/_time_monitor.py +58 -5
  454. mindspore/train/data_sink.py +5 -11
  455. mindspore/train/dataset_helper.py +84 -57
  456. mindspore/train/loss_scale_manager.py +2 -2
  457. mindspore/train/metrics/__init__.py +3 -3
  458. mindspore/train/metrics/cosine_similarity.py +1 -1
  459. mindspore/train/metrics/hausdorff_distance.py +3 -2
  460. mindspore/train/metrics/mean_surface_distance.py +3 -2
  461. mindspore/train/metrics/metric.py +39 -19
  462. mindspore/train/metrics/roc.py +2 -2
  463. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  464. mindspore/train/mind_ir_pb2.py +85 -36
  465. mindspore/train/model.py +187 -47
  466. mindspore/train/serialization.py +487 -161
  467. mindspore/train/summary/_summary_adapter.py +1 -1
  468. mindspore/train/summary/_writer_pool.py +3 -2
  469. mindspore/train/summary/summary_record.py +37 -17
  470. mindspore/train/train_thor/convert_utils.py +3 -3
  471. mindspore/train/train_thor/dataset_helper.py +1 -1
  472. mindspore/version.py +1 -1
  473. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +6 -7
  474. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +477 -517
  475. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -1
  476. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  477. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  478. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  479. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  480. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  481. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  482. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  483. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  484. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  485. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  486. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  487. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  488. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  489. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  490. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  491. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  492. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  493. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  494. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  495. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  496. mindspore/_extends/graph_kernel/expander.py +0 -80
  497. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  498. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  499. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  500. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  501. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  502. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  503. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  504. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  505. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  506. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  507. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  508. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  509. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  510. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  511. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  512. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  513. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  514. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  516. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  517. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  518. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  519. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  520. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  521. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  522. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  523. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  524. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  525. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  526. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  527. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  528. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  529. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  530. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  531. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  532. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  533. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  534. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  535. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  536. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  537. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  538. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  539. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  540. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  541. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  542. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  543. mindspore/dataset/datapreprocess/__init__.py +0 -20
  544. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  545. mindspore/include/api/net.h +0 -142
  546. mindspore/nn/lr_scheduler.py +0 -262
  547. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  548. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  549. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  550. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  551. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  552. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  553. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  554. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  555. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  556. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  557. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  558. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  559. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  560. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  561. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  562. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  563. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  564. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  565. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  566. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  567. mindspore/rewrite/node_visitor.py +0 -44
  568. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
  569. {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,541 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """obfuscate network based on rewrite interfaces."""
16
+ import os
17
+ import re
18
+ import secrets
19
+ from pathlib import Path
20
+
21
+ from mindspore import ops, nn
22
+ from mindspore.common.tensor import Tensor
23
+ from mindspore import log as logger
24
+ from mindspore import load_checkpoint, save_checkpoint
25
+ from mindspore.rewrite import SymbolTree, Node, NodeType, TreeNodeHelper, ScopedValue
26
+ from mindspore.rewrite.parsers.class_def_parser import ClassDefParser
27
+ from mindspore.rewrite.parsers.class_def_parser import ModuleParser
28
+
29
# Module-level state shared by the obfuscation helpers of this module.

# Number of duplicated "layers" on the target path: reset to 1 by _check_valid_target and multiplied by the
# length of every nn.CellList met while walking target_modules[0].
OBF_RATIOS_LENGTH = 1
# Number of names in target_modules[1] that actually exist on the resolved module.
OBF_RATIOS_WIDTH = 0
# Upper bound on the number of obf_ratios generated and of Mul nodes inserted.
MAX_OBF_RATIOS_NUM = 50
# Index of the next obf_ratio consumed when a Mul node is inserted into the network.
OBF_RATIOS_INSERT_INDEX = 0
33
+
34
+
35
def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
    """
    Obfuscate the plaintext checkpoint files. Usually used in conjunction with
    :func:`mindspore.load_obf_params_into_net` interface.

    Every ``.ckpt`` file directly inside `ckpt_files`, and every ``.ckpt`` file inside its immediate
    sub-directories, is loaded. The weights of the target modules are divided by the generated obf_ratios, and an
    ``_obf``-suffixed copy is saved under `saved_path`. Files found in a sub-directory are saved into a
    sub-directory of the same name, which is created with mode ``0o700`` when missing.

    Args:
        network (nn.Cell): The original network that needs to be obfuscated.
        ckpt_files (str): The directory path of original ckpt files.
        target_modules (list[str]): The target module of network that needs to be obfuscated.

            - The first string represents the network path of target module in original network, which should be
              in form of ``'A/B/C'``. A first string of ``'/'`` is treated as ``''`` (the main network) and is
              rewritten to ``''`` in the given list.
            - The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``.
              For example, the target_modules of GPT2 can be
              ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
            - If target_modules has a third value, it should be in the format of 'obfuscate_layers:all' or
              'obfuscate_layers:int'. It represents the number of duplicate layers (such as transformer layers
              or resnet blocks) that need to be obfuscated.

            If target_modules is ``None``, the function searches for target modules by itself. If found, the
            searched target module is used; otherwise suggested target modules are given in a warning log.
            Default: ``None``.
        saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
        obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be
            in range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.

    Raises:
        TypeError: If `network` is not nn.Cell.
        TypeError: If `ckpt_files` is not string or `saved_path` is not string.
        TypeError: If `ckpt_files` or `saved_path` exists but is not a directory.
        TypeError: If `target_modules` is not list.
        TypeError: If target_modules's elements are not string.
        ValueError: If `ckpt_files` or `saved_path` does not exist.
        ValueError: If the number of elements of `target_modules` is less than ``2``.
        ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
            letters, numbers, ``'_'`` and ``'/'``.
        ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase
            and lowercase letters, numbers, ``'_'`` and ``'|'``.
        ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
            'obfuscate_layers:int'.
        ValueError: If `obfuscate_scale` is not a float or int, or is not larger than ``1``.

    Returns:
        list[float], obf_ratios, which is the necessary data that needs to be loaded when running the obfuscated
        network. It is only returned, not written to disk, so the caller must keep it (e.g. with ``numpy.save``).

    Examples:
        >>> import numpy as np
        >>> from mindspore import obfuscate_ckpt, save_checkpoint
        >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> save_checkpoint(net, './test_net.ckpt')
        >>> target_modules = ['', 'fc1|fc2']
        >>> obf_ratios = obfuscate_ckpt(net, './', target_modules, './')
        >>> np.save('./obf_ratios.npy', obf_ratios)
    """
    global OBF_RATIOS_LENGTH
    if not isinstance(network, nn.Cell):
        raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
    _check_dir_path('ckpt_files', ckpt_files)
    _check_dir_path('saved_path', saved_path)
    # Try to find default target modules
    if target_modules is None:
        to_split_modules = _get_default_target_modules(ckpt_files)
    else:
        # '/' is an alias of '' (main network). The list is normalized in place, so the same list can later be
        # handed to load_obf_params_into_net.
        if len(target_modules) >= 1 and target_modules[0] == '/':
            target_modules[0] = ''
        to_split_modules = target_modules
    # _check_valid_target also (re)computes OBF_RATIOS_LENGTH / OBF_RATIOS_WIDTH / MAX_OBF_RATIOS_NUM.
    if not _check_valid_target(network, to_split_modules):
        raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'."
                         .format(to_split_modules))
    if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
        raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
                         .format(obfuscate_scale))
    # Generate obf_ratios: one per (layer, target). They are returned to the caller, not saved to disk.
    path_list = to_split_modules[0].split('/')
    target_list = to_split_modules[1].split('|')
    number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    if number_of_ratios > MAX_OBF_RATIOS_NUM:
        # Cap the ratio count by dropping whole layers; layers past the cap stay in plaintext.
        OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
        number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    secrets_generator = secrets.SystemRandom()
    obf_ratios = [secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale) for _ in range(number_of_ratios)]
    # Obfuscate every ckpt of the directory and of its immediate sub-directories.
    ckpt_root = os.path.abspath(ckpt_files)
    for ckpt_name in os.listdir(ckpt_files):
        sub_path = os.path.join(ckpt_root, ckpt_name)
        if Path(sub_path).is_dir():
            sub_ckpt_file_list = os.listdir(sub_path)
            new_saved_path = os.path.join(os.path.abspath(saved_path), ckpt_name)
            if not os.path.exists(new_saved_path):
                try:
                    os.mkdir(new_saved_path, mode=0o700)
                except FileExistsError:
                    # Created concurrently between the existence check and mkdir: nothing to do.
                    pass
            for sub_ckpt_name in sub_ckpt_file_list:
                if not sub_ckpt_name.endswith('.ckpt'):
                    continue
                _obfuscate_single_ckpt(os.path.join(sub_path, sub_ckpt_name), obf_ratios, path_list,
                                       target_list, new_saved_path)
        else:
            if not ckpt_name.endswith('.ckpt'):
                continue
            _obfuscate_single_ckpt(sub_path, obf_ratios, path_list, target_list, saved_path)
    return obf_ratios
135
+
136
+
137
def _obf_ckpt_file_name(ckpt_name):
    """Return the file name of the obfuscated copy of `ckpt_name`: ``<stem>_obf.ckpt``.

    Only the final extension is stripped. ``a.v1.ckpt`` therefore maps to ``a.v1_obf.ckpt`` instead of
    colliding with ``a.v2.ckpt`` on ``a_obf.ckpt``.
    """
    stem = os.path.splitext(os.path.basename(ckpt_name))[0]
    return stem + '_obf' + '.ckpt'


def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path):
    """Obfuscate a single ckpt file and save the result into `saved_path`.

    Args:
        ckpt_name (str): Full path of the plaintext checkpoint file.
        obf_ratios (list[float]): Ratios generated by `obfuscate_ckpt`; matching weights are divided by them.
        path_list (list[str]): `target_modules[0]` split on ``'/'``.
        target_list (list[str]): `target_modules[1]` split on ``'|'``.
        saved_path (str): Directory where the obfuscated checkpoint is written.

    Returns:
        None. A checkpoint that fails to load is logged and skipped.
    """
    try:
        ckpt_param = load_checkpoint(ckpt_name)
    except (ValueError, TypeError, OSError):
        logger.error("Load checkpoint failed for file {}.".format(ckpt_name))
        return None
    module_has_been_obfuscated = set()
    obf_ratios_index = -1
    for item in ckpt_param:
        module = _get_valid_module(item, path_list, target_list)
        if not module:
            continue
        layer_index = _judge_layer_index(item)
        # Layers beyond the (possibly capped) ratio table are left in plaintext.
        if layer_index >= OBF_RATIOS_LENGTH:
            continue
        # The column index advances once per distinct module, so all parameters of a module share one ratio.
        if module not in module_has_been_obfuscated:
            module_has_been_obfuscated.add(module)
            obf_ratios_index += 1
        ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH
        ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index])
    # save the obfuscated model to saved_path
    obf_param_list = [{'name': item, 'data': ckpt_param[item]} for item in ckpt_param]
    save_checkpoint(obf_param_list, os.path.join(os.path.abspath(saved_path), _obf_ckpt_file_name(ckpt_name)))
    return None
165
+
166
+
167
def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs):
    """
    Load obfuscate ratios into obfuscated network. Usually used in conjunction with
    :func:`mindspore.obfuscate_ckpt` interface.

    Args:
        network (nn.Cell): The original network that needs to be obfuscated.
        target_modules (list[str]): The target module of network that needs to be obfuscated.

            - The first string represents the network path of target module in original network, which should be
              in form of ``'A/B/C'``. A first string of ``'/'`` is treated as ``''`` (the main network) and is
              rewritten to ``''`` in the given list.
            - The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``.
              For example, the target_modules of GPT2 can be
              ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
            - If target_modules has a third value, it should be in the format of 'obfuscate_layers:all' or
              'obfuscate_layers:int'. It represents the number of duplicate layers (such as transformer layers
              or resnet blocks) that need to be obfuscated.
        obf_ratios (Tensor): The obf ratios generated when executing :func:`mindspore.obfuscate_ckpt`.
        data_parallel_num (int): The data parallel number of parallel training. Default: 1.
        kwargs (dict): Configuration options dictionary.

            - ignored_func_decorators (list[str]): The name list of function decorators in network's python code.
            - ignored_class_decorators (list[str]): The name list of class decorators in network's python code.

    Returns:
        nn.Cell, the rewritten network with Mul operations inserted after the target modules. The given
        `obf_ratios` is attached to it as the attribute ``obf_ratios``.

    Raises:
        TypeError: If `network` is not nn.Cell.
        TypeError: If `obf_ratios` is not Tensor.
        TypeError: If `target_modules` is not list.
        TypeError: If target_modules's elements are not string.
        ValueError: If the number of elements of `target_modules` is less than ``2``.
        ValueError: If `obf_ratios` is empty Tensor.
        ValueError: If `data_parallel_num` is not a positive int.
        ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
            letters, numbers, ``'_'`` and ``'/'``.
        ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase
            and lowercase letters, numbers, ``'_'`` and ``'|'``.
        ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
            'obfuscate_layers:int'.
        TypeError: If `ignored_func_decorators` is not list[str] or `ignored_class_decorators` is not list[str].

    Examples:
        >>> import mindspore.common.dtype as mstype
        >>> from mindspore import obfuscate_ckpt, load_obf_params_into_net, save_checkpoint, load_checkpoint, Tensor
        >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
        >>> net = LeNet5()
        >>> save_checkpoint(net, './test_net.ckpt')
        >>> target_modules = ['', 'fc1|fc2']
        >>> # obfuscate ckpt files
        >>> obf_ratios = obfuscate_ckpt(net, './', target_modules, './')
        >>> # load obf ckpt into network
        >>> new_net = LeNet5()
        >>> load_checkpoint('./test_net_obf.ckpt', new_net)
        >>> obf_net = load_obf_params_into_net(new_net, target_modules, Tensor(obf_ratios, mstype.float16))
    """
    global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
    if not isinstance(network, nn.Cell):
        raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
    if not isinstance(obf_ratios, Tensor):
        raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios)))
    if obf_ratios.size == 0:
        raise ValueError("obf_ratios can not be empty.")
    # '/' is an alias of '' (main network). Normalize it in place before validation, as obfuscate_ckpt does;
    # otherwise _check_valid_target rejects the '/' path.
    if isinstance(target_modules, list) and target_modules and target_modules[0] == '/':
        target_modules[0] = ''
    # _check_valid_target also (re)computes OBF_RATIOS_LENGTH / OBF_RATIOS_WIDTH / MAX_OBF_RATIOS_NUM.
    if not _check_valid_target(network, target_modules):
        raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
    if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
        raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
    path_list = target_modules[0].split('/')
    # target_list[i] names the cells to follow with a Mul in the tree at depth i (0 = main network).
    # Only the innermost module, at depth len(path_list), gets targets.
    target_list = [[] for _ in path_list]
    target_list.append(target_modules[1].split('|'))
    number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    if number_of_ratios > MAX_OBF_RATIOS_NUM:
        OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
        number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    # Caps the number of Mul nodes _obfuscate_network inserts at the number of ratios available.
    MAX_OBF_RATIOS_NUM = number_of_ratios
    rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
    setattr(rewrite_network, 'obf_ratios', obf_ratios)
    return rewrite_network
246
+
247
+
248
def _check_dir_path(name, dir_path):
    """Validate that `dir_path` is a string naming an existing directory.

    Args:
        name (str): Argument name used in error messages.
        dir_path: Value to validate.

    Raises:
        TypeError: If `dir_path` is not a string, or exists but is not a directory.
        ValueError: If `dir_path` does not exist.
    """
    if isinstance(dir_path, str):
        if not os.path.exists(dir_path):
            raise ValueError("{} is not exist, please check the input {}.".format(dir_path, name))
        if Path(dir_path).is_dir():
            return
        raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path))
    raise TypeError("{} must be string, but got {}.".format(name, type(dir_path)))
256
+
257
+
258
def _judge_layer_index(layer_name):
    """Return the first purely numeric '.'-separated component of `layer_name` as an int, or 0 if there is none.

    Example: ``'backbone.blocks.3.attention.dense1.weight'`` gives ``3``.
    """
    numeric_parts = (int(part) for part in layer_name.split('.') if part.isdigit())
    return next(numeric_parts, 0)
265
+
266
+
267
def _check_valid_target(network, target_modules):
    """Check that `target_modules` is well formed and exists in `network`, updating the ratio globals.

    `target_modules[0]` is a '/'-separated attribute path from `network` ('' means the main network).
    `target_modules[1]` is a '|'-separated list of target attribute names on the module the path resolves to.

    Side effects on module globals:
        OBF_RATIOS_LENGTH: reset to 1, then multiplied by ``len()`` of every nn.CellList met along the path.
        OBF_RATIOS_WIDTH: set to the number of names in `target_modules[1]` that exist on the resolved module.
        MAX_OBF_RATIOS_NUM: updated through `_update_max_obf_ratios_num` when a third element is given.

    Returns:
        bool: Always ``True``; every failure is reported by raising.

    Raises:
        TypeError: If `target_modules` is not a list, if its first two elements are not strings, or if an
            attribute on the path is neither nn.Cell nor nn.CellList.
        ValueError: If `target_modules` has fewer than two elements, if `target_modules[1]` is empty, if either
            string contains disallowed characters, if the path does not exist, or if none of the targets exist.
    """
    global OBF_RATIOS_LENGTH, OBF_RATIOS_WIDTH
    if not isinstance(target_modules, list):
        raise TypeError("target_modules type should be list, but got {}.".format(type(target_modules)))
    if len(target_modules) < 2:
        raise ValueError("target_modules should contain at least two string values, in the form of ['A/B/C', 'D1|D2'],"
                         "but got {}.".format(target_modules))
    if (not isinstance(target_modules[0], str)) or (not isinstance(target_modules[1], str)):
        raise TypeError("The values of target_modules should be string, but got {} and {}.".
                        format(type(target_modules[0]), type(target_modules[1])))

    if not target_modules[1]:
        raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'"
                         .format(target_modules[1]))
    # These flat character classes accept exactly the strings the former nested patterns
    # r'([a-zA-Z]*[0-9]*\/*_*)*' and r'([a-zA-Z]*[0-9]*\|*_*)*' accepted. They avoid the exponential
    # backtracking those patterns hit on long strings that fail to match.
    if not re.fullmatch(pattern=r'[a-zA-Z0-9/_]*', string=target_modules[0]) \
            or not re.fullmatch(pattern=r'[a-zA-Z0-9|_]*', string=target_modules[1]):
        raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']."
                         "target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/',"
                         "target_modules[1] can only contain uppercase and lowercase letters, numbers, '_' and '|'"
                         .format(target_modules))
    # target_modules[0] is allowed to be '', it means the main network path
    path_list = target_modules[0].split('/')
    target_list = target_modules[1].split('|')
    net = network
    # Walk path_list with an explicit stack. A CellList pushes all of its elements; they are popped last-first
    # until one has the next path component.
    stk = [net]
    i = 0
    OBF_RATIOS_LENGTH = 1
    while stk and i < len(path_list):
        net = stk.pop()
        if hasattr(net, path_list[i]):
            net = getattr(net, path_list[i])
            i += 1
            if isinstance(net, nn.CellList):
                OBF_RATIOS_LENGTH *= len(net)
                for n in net:
                    stk.append(n)
            elif isinstance(net, nn.Cell):
                stk.append(net)
            else:
                raise TypeError("Target_modules[0] should be a subgraph and it's type should be nn.Cell(nn.CellList),"
                                "but got type {}".format(type(net)))
    if target_modules[0] != '' and i != len(path_list):
        raise ValueError("the path {} does not exist.".format(target_modules[0]))
    # check whether target_list is valid
    OBF_RATIOS_WIDTH = 0
    for target in target_list:
        if not hasattr(net, target):
            logger.warning("{} does not exist in the path {}".format(target, target_modules[0]))
        else:
            OBF_RATIOS_WIDTH += 1
    if OBF_RATIOS_WIDTH == 0:
        raise ValueError("all targets {} do not exist in the path {}.".format(target_list, target_modules[0]))
    _update_max_obf_ratios_num(target_modules)
    return True
324
+
325
+
326
def _update_max_obf_ratios_num(target_modules):
    """Update MAX_OBF_RATIOS_NUM from the optional third element of `target_modules`.

    ``'obfuscate_layers:all'`` allows one ratio per target in every layer (OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH).
    ``'obfuscate_layers:N'`` allows ratios for N layers (N * OBF_RATIOS_WIDTH). Without a third element the
    global is left unchanged.

    Raises:
        TypeError: If the third element is not a string.
        ValueError: If the third element is not in one of the formats above.
    """
    global MAX_OBF_RATIOS_NUM
    if len(target_modules) < 3:
        return
    layers_spec = target_modules[2]
    if not isinstance(layers_spec, str):
        # Previously this surfaced as an AttributeError from .split(); report the documented TypeError instead.
        raise TypeError("The third value of target_modules should be string, but got {}.".format(type(layers_spec)))
    obfuscate_layers = layers_spec.split(':')
    if len(obfuscate_layers) != 2 or obfuscate_layers[0] != 'obfuscate_layers':
        raise ValueError("The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
                         "'obfuscate_layers:int'")
    if obfuscate_layers[1] == 'all':
        MAX_OBF_RATIOS_NUM = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
    else:
        if not obfuscate_layers[1].isdigit():
            raise ValueError(
                "The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
                "'obfuscate_layers:int'")
        MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH
342
+
343
+
344
def _search_default_obfuscate_modules(param_paths):
    """Search default and suggested obfuscation modules among '/'-joined, digit-free parameter paths.

    A path that contains ``'attention'`` is a candidate; default target names found in candidates give the
    default paths and targets. Any path containing a default target name yields a suggestion. Results keep
    first-seen order, so the output is deterministic.

    Args:
        param_paths (Iterable[str]): Parameter paths such as ``'backbone/blocks/attention/dense1/weight'``.

    Returns:
        tuple(list[str], list[str], list[list[str]]): ``(paths, targets, suggest_modules)``.
    """
    default_modules = ('attention',)
    default_targets = ('dense', 'query', 'key', 'value')
    seen_candidates = set()
    paths = []
    targets = []
    suggest_modules = []

    def _split_to_path_and_target(module, target):
        # Split `module` at the first occurrence of `target`. The part before it (minus the separating '/')
        # is the path; the '/'-delimited component starting there is the full target name.
        target_index = module.index(target)
        return module[:target_index - 1], module[target_index:].split('/')[0]

    for net_path in param_paths:
        # Each candidate module is scanned once, when first seen.
        if net_path not in seen_candidates and any(module in net_path for module in default_modules):
            seen_candidates.add(net_path)
            for target in default_targets:
                if target in net_path:
                    path, full_target = _split_to_path_and_target(net_path, target)
                    if path not in paths:
                        paths.append(path)
                    if full_target not in targets:
                        targets.append(full_target)
        # Suggested modules: any path containing one of the default target names.
        for target in default_targets:
            if target in net_path:
                path, full_target = _split_to_path_and_target(net_path, target)
                if [path, full_target] not in suggest_modules:
                    suggest_modules.append([path, full_target])
    return paths, targets, suggest_modules


def _get_default_target_modules(ckpt_files):
    """Get the default or suggested target modules, if the target modules is None.

    Every ``.ckpt`` file directly inside `ckpt_files` is loaded; unreadable files are logged and skipped. Its
    parameter names are then searched with `_search_default_obfuscate_modules`.

    Returns:
        list[str]: ``[path, 'T1|T2|...']``, built from the first default path and all default targets found.

    Raises:
        ValueError: If no default target module can be determined. Suggested modules are logged as a warning
            first.
    """
    param_paths = []
    ckpt_root = os.path.abspath(ckpt_files)
    for ckpt_name in os.listdir(ckpt_files):
        if not ckpt_name.endswith('.ckpt'):
            continue
        ckpt_path = os.path.join(ckpt_root, ckpt_name)
        try:
            ckpt_param = load_checkpoint(ckpt_path)
        except (ValueError, TypeError, OSError):
            # Skip the unreadable file instead of returning None, which only failed later with a confusing
            # "target_modules type should be list" error.
            logger.error("Load checkpoint failed for file {}.".format(ckpt_path))
            continue
        param_paths.extend('/'.join(_remove_digit(item)) for item in ckpt_param)
    paths, targets, suggest_modules = _search_default_obfuscate_modules(param_paths)
    if paths and targets:
        target_modules = [paths[0], '|'.join(targets)]
        logger.warning("The default obfuscate modules is obtained:{}".format(target_modules))
        return target_modules
    # logging the suggested target module
    logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}"
                   .format(suggest_modules))
    raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']")
409
+
410
+
411
def _get_valid_module(item, path_list, target_list):
    """Return the key of the target module owning parameter `item`, or None when it is not a target.

    `item` is a '.'-separated parameter name. It lies on the target path when its first ``len(path_list)``
    digit-free components equal `path_list`. The first name of `target_list` (in list order) found among its
    components then selects the module. The key is the components up to and including that name, concatenated
    without separators.
    """
    parts = item.split('.')
    prefix = '/'.join(_remove_digit(item)[:len(path_list)])
    if prefix != '/'.join(path_list):
        return None
    hit = next((target for target in target_list if target in parts), None)
    if hit is None:
        return None
    return ''.join(parts[:parts.index(hit) + 1])
425
+
426
+
427
def _remove_digit(item):
    """Split the '.'-separated parameter path `item` and drop every purely numeric component."""
    return [part for part in item.split('.') if not part.isdigit()]
434
+
435
+
436
def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
    """Obfuscate original network: add a ``y_obf`` input and insert Mul operations after the target cells.

    Args:
        model (nn.Cell): Network to rewrite.
        path_list (list[str]): `target_modules[0]` split on ``'/'``; sub-trees whose node name starts with
            ``path_list[i]`` are followed at depth ``i``.
        target_list (list[list[str]]): ``target_list[i]`` names the cells followed by a Mul in the tree at depth
            ``i`` (index 0 is the main network).
        data_parallel_num (int): When larger than 1, each inserted Mul is sharded with ``((data_parallel_num, 1), ())``.
        kwargs: ``ignored_func_decorators`` / ``ignored_class_decorators`` (list[str]), registered on the rewrite
            parsers as decorators to deny.

    Returns:
        nn.Cell: The network returned by ``SymbolTree.get_network`` after rewriting.

    Raises:
        TypeError: If a decorator option is not a list, or contains a non-string element.
    """
    global OBF_RATIOS_INSERT_INDEX
    # Mul nodes consume obf_ratios sequentially from index 0. Restart for every network; otherwise a second call
    # in the same process finds the index already at MAX_OBF_RATIOS_NUM and inserts nothing.
    OBF_RATIOS_INSERT_INDEX = 0

    def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'):
        """add inputs for passing obf_ratio"""
        last_input = None
        for node in stree.nodes():
            if node.get_node_type() == NodeType.Input:
                last_input = node
        position = stree.after(last_input)
        # the insert input node name would be 'input_y_obf'
        new_input_node = last_input.create_input(arg_name)
        stree.insert(position, new_input_node)

    def _insert_mul(stree: SymbolTree, node: Node, index: int):
        """add mul operation (node output * y_obf[index]) for original network"""
        arg_list = node.get_targets().copy()
        input_y_node = stree.get_node("input_y_obf")
        v: str = input_y_node.get_targets()[0].value
        sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]')
        arg_list.append(sv)
        mul_targets = node.get_targets().copy()
        if data_parallel_num > 1:
            logger.info("Data parallel number is: {}".format(data_parallel_num))
            new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
                                                 targets=mul_targets, args=arg_list, name='mul')
        else:
            new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=mul_targets, args=arg_list, name='mul')
        position = stree.after(node)
        stree.insert(position, new_mul_node)

    def _insert_mul_by_name(stree: SymbolTree, after_name_list: list):
        """add mul operation after the target nodes according the name of them"""
        global OBF_RATIOS_INSERT_INDEX
        if not after_name_list:
            return
        for node in stree.nodes():
            for after_name in after_name_list:
                # Stop inserting once every available ratio has been used.
                if node.get_name() == after_name and OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM:
                    _insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX)
                    OBF_RATIOS_INSERT_INDEX += 1

    def _update_subnet(stree: SymbolTree, substree: SymbolTree, subnode: Node):
        """update the network once the subnet is obfuscated: call it with an extra y_obf argument"""
        new_net = substree.get_network()
        input_y_node = substree.get_node("input_y_obf")
        if input_y_node is None:
            return
        arg_list = subnode.get_args().copy()
        kwargs_list = list(subnode.get_kwargs().values())
        arg_list.extend(kwargs_list)
        v: str = input_y_node.get_targets()[0].value
        arg_obf: ScopedValue = ScopedValue.create_naming_value("y_obf=" + v)
        arg_list.append(arg_obf)
        call_targets = subnode.get_targets().copy()
        name = subnode.get_name()
        new_node = subnode.create_call_cell(cell=new_net, targets=call_targets, args=arg_list, name=name)
        stree.replace(subnode, [new_node])

    def _traverse(stree, i=0):
        """traverse and obfuscate the original network, depth-first along path_list"""
        if len(path_list) == i:
            return
        for node in stree.nodes():
            node_name = node.get_name()
            if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]):
                sub_stree = TreeNodeHelper.get_sub_tree(node)
                _traverse(sub_stree, i + 1)
                _insert_input(sub_stree, arg_name='y_obf')
                _insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1])
                _update_subnet(stree, sub_stree, node)

    def _check_decorator_list(names, list_error_msg):
        """Raise TypeError unless `names` is a list whose elements are all str."""
        if not isinstance(names, list):
            raise TypeError(list_error_msg.format(names, type(names)))
        for dec_name in names:
            if not isinstance(dec_name, str):
                raise TypeError('elements of {} should be str, but got {}'.format(names, type(dec_name)))

    if 'ignored_func_decorators' in kwargs:
        kw_func_dec = kwargs["ignored_func_decorators"]
        _check_decorator_list(kw_func_dec, '{} should be list, but got {}')
        setattr(ClassDefParser, "denied_function_decorator_list", kw_func_dec)
    else:
        setattr(ClassDefParser, "denied_function_decorator_list",
                ["_args_type_validator_check", "_LogActionOnce", "cell_attr_register"])
    if 'ignored_class_decorators' in kwargs:
        kw_class_dec = kwargs["ignored_class_decorators"]
        # Validate before registering, so an invalid value never reaches the parser's global state.
        _check_decorator_list(kw_class_dec, '{} should be list[str] type, but got {}')
        setattr(ModuleParser, "denied_class_decorator_list", kw_class_dec)

    main_stree = SymbolTree.create(model)
    _traverse(main_stree, 0)
    _insert_input(main_stree, arg_name='y_obf')
    _insert_mul_by_name(main_stree, after_name_list=target_list[0])
    new_net = main_stree.get_network()
    return new_net
mindspore/scipy/linalg.py CHANGED
@@ -461,8 +461,8 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
461
461
  loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
462
462
  x = permutation[..., i]
463
463
  y = permutation[loc + (j,)]
464
- permutation[..., i] = y
465
464
  permutation[loc + (j,)] = x
465
+ permutation[..., i] = y
466
466
  return permutation
467
467
 
468
468
 
@@ -99,10 +99,14 @@ def minimize(func, x0, args=(), method=None, jac=None, hess=None, hessp=None, bo
99
99
  if it is a callable, it should be a function that returns the gradient vector:
100
100
  :math:`jac(x, *args) -> array\_like, shape (n,)`
101
101
  where x is an array with shape :math:`(n,)` and args is a tuple with the fixed parameters.
102
+ hess (Callable, optional): Method for calculating the Hessian Matrix. Not implemented yet.
103
+ hessp (Callable, optional): Hessian of objective function times an arbitrary vector `p`. Not implemented yet.
104
+ bounds (Sequence, optional): Sequence of `(min, max)` pairs for each element in `x`. Not implemented yet.
105
+ constraints (Callable, optional): representing the inequality constrains, each function in constrains indicates
106
+ the function < 0 as an inequality constrain.
102
107
  tol (float, optional): tolerance for termination. For detailed control, use solver-specific
103
108
  options. Default: ``None`` .
104
- constraints(Callable, optional): representing the inequality constrains, each function in constrains indicates
105
- the function < 0 as an inequality constrain.
109
+ callback (Callable, optional): A callable called after each iteration. Not implemented yet.
106
110
  options (Mapping[str, Any], optional): a dictionary of solver options. All methods accept the following
107
111
  generic options. Default: ``None`` .
108
112
 
@@ -111,7 +115,7 @@ def minimize(func, x0, args=(), method=None, jac=None, hess=None, hessp=None, bo
111
115
  - maxiter (int): Maximum number of iterations to perform. Depending on the
112
116
  method each iteration may use several function evaluations.
113
117
 
114
- The follow options are exclusive to Lagrange method:
118
+ The follow options are exclusive to Lagrange method:
115
119
 
116
120
  - save_tol (list): list of saving tolerance, with the same length with 'constrains'.
117
121
  - obj_weight (float): weight for objective function, usually between 1.0 - 100000.0.
mindspore/train/_utils.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2020-2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -66,7 +66,11 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset', create_data_inf
66
66
  # transform data format
67
67
  dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
68
68
  send_epoch_end = bool(dataset_size == -1)
69
- exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end, create_data_info_queue=create_data_info_queue)
69
+ queue_name = _cell_graph_executor.get_queue_name(phase)
70
+ if queue_name is None:
71
+ queue_name = str("")
72
+ exec_dataset = exec_dataset.device_que(send_epoch_end=send_epoch_end,
73
+ create_data_info_queue=create_data_info_queue, queue_name=queue_name)
70
74
  _cell_graph_executor.init_dataset(exec_dataset.queue_name,
71
75
  dataset_size,
72
76
  batch_size,
@@ -105,7 +109,7 @@ def _construct_tensor_list(types, shapes, batch_expand_num=1):
105
109
  new_shape += (item * batch_expand_num,)
106
110
  else:
107
111
  new_shape += (item,)
108
- tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)))
112
+ tensor = Tensor(np.zeros(new_shape, dtype_to_nptype(type_)), dtype=type_)
109
113
  tensor.virtual_flag = True
110
114
  tensor_list.append(tensor)
111
115
  return tensor_list