mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. (The original page linked to further details here; that link was not preserved in this copy.)

Files changed (589)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +4 -1
  3. mindspore/_akg/akg/build_module.py +5 -6
  4. mindspore/_akg/akg/composite/build_module.py +139 -22
  5. mindspore/_akg/akg/composite/split_stitch.py +10 -11
  6. mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
  7. mindspore/_akg/akg/tvm/api.py +4 -3
  8. mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
  9. mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
  10. mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
  11. mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
  12. mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
  13. mindspore/_akg/akg/tvm/build_module.py +16 -1
  14. mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
  15. mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
  16. mindspore/_akg/akg/tvm/ir_builder.py +1 -1
  17. mindspore/_akg/akg/tvm/module.py +1 -2
  18. mindspore/_akg/akg/tvm/stmt.py +2 -2
  19. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  20. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  21. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  22. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  23. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  24. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  25. mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
  26. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  27. mindspore/_akg/akg/utils/kernel_exec.py +98 -274
  28. mindspore/_akg/akg/utils/result_analysis.py +4 -24
  29. mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
  30. mindspore/_akg/akg/utils/util.py +56 -1
  31. mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
  32. mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
  33. mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
  34. mindspore/_check_jit_forbidden_api.py +3 -1
  35. mindspore/_checkparam.py +23 -29
  36. mindspore/_extends/graph_kernel/__init__.py +0 -1
  37. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  38. mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
  39. mindspore/_extends/graph_kernel/splitter.py +4 -11
  40. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
  41. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
  42. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
  43. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  44. mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
  45. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
  46. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
  47. mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
  48. mindspore/_extends/parse/__init__.py +13 -15
  49. mindspore/_extends/parse/namespace.py +7 -33
  50. mindspore/_extends/parse/parser.py +67 -72
  51. mindspore/_extends/parse/resources.py +1 -1
  52. mindspore/_extends/parse/standard_method.py +86 -106
  53. mindspore/_extends/parse/trope.py +1 -1
  54. mindspore/_extends/remote/kernel_build_server.py +25 -7
  55. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  56. mindspore/_install_custom.py +43 -0
  57. mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
  58. mindspore/amp.py +47 -11
  59. mindspore/bin/cache_admin +0 -0
  60. mindspore/bin/cache_server +0 -0
  61. mindspore/boost/boost.py +1 -8
  62. mindspore/boost/boost_cell_wrapper.py +3 -2
  63. mindspore/boost/grad_accumulation.py +1 -1
  64. mindspore/boost/group_loss_scale_manager.py +8 -7
  65. mindspore/common/__init__.py +5 -3
  66. mindspore/common/_jit_fallback_utils.py +6 -0
  67. mindspore/common/_register_for_adapter.py +2 -0
  68. mindspore/common/_register_for_tensor.py +2 -2
  69. mindspore/common/_stub_tensor.py +13 -0
  70. mindspore/common/_utils.py +29 -0
  71. mindspore/common/api.py +174 -259
  72. mindspore/common/auto_dynamic_shape.py +494 -0
  73. mindspore/common/dtype.py +18 -11
  74. mindspore/common/dump.py +6 -4
  75. mindspore/common/initializer.py +14 -14
  76. mindspore/common/jit_config.py +33 -15
  77. mindspore/common/lazy_inline.py +126 -7
  78. mindspore/common/mindir_util.py +101 -0
  79. mindspore/common/parameter.py +51 -41
  80. mindspore/common/seed.py +4 -4
  81. mindspore/common/sparse_tensor.py +13 -14
  82. mindspore/common/tensor.py +243 -165
  83. mindspore/communication/__init__.py +7 -4
  84. mindspore/communication/_comm_helper.py +83 -4
  85. mindspore/communication/management.py +152 -84
  86. mindspore/config/op_info.config +14 -3
  87. mindspore/config/super_bar_config.json +4 -2
  88. mindspore/context.py +152 -61
  89. mindspore/dataset/__init__.py +5 -5
  90. mindspore/dataset/audio/__init__.py +2 -2
  91. mindspore/dataset/audio/transforms.py +52 -52
  92. mindspore/dataset/callback/ds_callback.py +16 -2
  93. mindspore/dataset/core/config.py +68 -51
  94. mindspore/dataset/engine/cache_client.py +33 -7
  95. mindspore/dataset/engine/datasets.py +250 -112
  96. mindspore/dataset/engine/datasets_audio.py +43 -211
  97. mindspore/dataset/engine/datasets_standard_format.py +16 -35
  98. mindspore/dataset/engine/datasets_text.py +43 -67
  99. mindspore/dataset/engine/datasets_user_defined.py +86 -100
  100. mindspore/dataset/engine/datasets_vision.py +219 -1029
  101. mindspore/dataset/engine/iterators.py +11 -4
  102. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
  103. mindspore/dataset/engine/obs/util.py +3 -0
  104. mindspore/dataset/engine/samplers.py +1 -1
  105. mindspore/dataset/engine/validators.py +19 -5
  106. mindspore/dataset/text/__init__.py +3 -3
  107. mindspore/dataset/text/transforms.py +101 -127
  108. mindspore/dataset/text/utils.py +205 -138
  109. mindspore/dataset/transforms/__init__.py +1 -1
  110. mindspore/dataset/transforms/py_transforms_util.py +40 -12
  111. mindspore/dataset/transforms/transforms.py +95 -40
  112. mindspore/dataset/utils/browse_dataset.py +8 -2
  113. mindspore/dataset/utils/line_reader.py +17 -19
  114. mindspore/dataset/vision/__init__.py +3 -3
  115. mindspore/dataset/vision/c_transforms.py +6 -3
  116. mindspore/dataset/vision/transforms.py +409 -287
  117. mindspore/dataset/vision/utils.py +13 -14
  118. mindspore/dataset/vision/validators.py +11 -1
  119. mindspore/experimental/map_parameter.py +14 -0
  120. mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
  121. mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
  122. mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
  123. mindspore/experimental/optim/lr_scheduler.py +1427 -0
  124. mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
  125. mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
  126. mindspore/gen_ops.py +273 -0
  127. mindspore/include/OWNERS +0 -1
  128. mindspore/include/api/data_type.h +2 -1
  129. mindspore/include/api/graph.h +0 -15
  130. mindspore/include/api/kernel.h +2 -0
  131. mindspore/include/api/kernel_api.h +37 -12
  132. mindspore/include/api/model.h +17 -14
  133. mindspore/include/api/status.h +8 -3
  134. mindspore/include/api/types.h +37 -4
  135. mindspore/include/c_api/ms/abstract.h +67 -0
  136. mindspore/include/c_api/ms/attribute.h +197 -0
  137. mindspore/include/c_api/ms/base/handle_types.h +43 -0
  138. mindspore/include/c_api/ms/base/macros.h +32 -0
  139. mindspore/include/c_api/ms/base/status.h +33 -0
  140. mindspore/include/c_api/ms/base/types.h +282 -0
  141. mindspore/include/c_api/ms/context.h +102 -0
  142. mindspore/include/c_api/ms/graph.h +160 -0
  143. mindspore/include/c_api/ms/node.h +606 -0
  144. mindspore/include/c_api/ms/tensor.h +161 -0
  145. mindspore/include/c_api/ms/value.h +84 -0
  146. mindspore/include/dataset/constants.h +6 -5
  147. mindspore/include/dataset/execute.h +23 -13
  148. mindspore/include/dataset/text.h +26 -26
  149. mindspore/include/dataset/transforms.h +13 -13
  150. mindspore/include/dataset/vision.h +60 -60
  151. mindspore/include/dataset/vision_ascend.h +5 -6
  152. mindspore/include/dataset/vision_lite.h +17 -17
  153. mindspore/include/mindapi/base/type_id.h +1 -0
  154. mindspore/include/mindapi/base/types.h +1 -0
  155. mindspore/lib/libdnnl.so.2 +0 -0
  156. mindspore/lib/libjemalloc.so.2 +0 -0
  157. mindspore/lib/libmindspore.so +0 -0
  158. mindspore/lib/libmindspore_backend.so +0 -0
  159. mindspore/lib/libmindspore_common.so +0 -0
  160. mindspore/lib/libmindspore_core.so +0 -0
  161. mindspore/lib/libmindspore_glog.so.0 +0 -0
  162. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  163. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  164. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  165. mindspore/lib/libmindspore_shared_lib.so +0 -0
  166. mindspore/lib/libnnacl.so +0 -0
  167. mindspore/lib/libopencv_core.so.4.5 +0 -0
  168. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  169. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  170. mindspore/lib/libps_cache.so +0 -0
  171. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  172. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  173. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  174. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  175. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  176. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  177. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  178. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  179. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  180. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  181. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  182. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  183. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  184. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  185. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  186. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
  187. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  188. mindspore/lib/plugin/ascend/libakg.so +0 -0
  189. mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
  190. mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
  191. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  192. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  193. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  194. mindspore/lib/plugin/cpu/libakg.so +0 -0
  195. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  196. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  197. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  198. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  199. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  200. mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
  201. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  202. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  203. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  204. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  205. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  206. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  207. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  208. mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
  209. mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
  210. mindspore/nn/__init__.py +0 -2
  211. mindspore/nn/cell.py +313 -74
  212. mindspore/nn/dynamic_lr.py +21 -21
  213. mindspore/nn/layer/activation.py +22 -30
  214. mindspore/nn/layer/basic.py +15 -13
  215. mindspore/nn/layer/channel_shuffle.py +1 -1
  216. mindspore/nn/layer/container.py +271 -9
  217. mindspore/nn/layer/conv.py +323 -204
  218. mindspore/nn/layer/dense.py +8 -5
  219. mindspore/nn/layer/embedding.py +33 -27
  220. mindspore/nn/layer/flash_attention.py +61 -95
  221. mindspore/nn/layer/image.py +8 -6
  222. mindspore/nn/layer/math.py +16 -25
  223. mindspore/nn/layer/normalization.py +107 -66
  224. mindspore/nn/layer/padding.py +1 -1
  225. mindspore/nn/layer/pooling.py +131 -109
  226. mindspore/nn/layer/rnn_cells.py +27 -22
  227. mindspore/nn/layer/rnns.py +13 -16
  228. mindspore/nn/layer/thor_layer.py +1 -1
  229. mindspore/nn/layer/transformer.py +221 -154
  230. mindspore/nn/learning_rate_schedule.py +9 -1
  231. mindspore/nn/loss/loss.py +235 -174
  232. mindspore/nn/optim/ada_grad.py +2 -1
  233. mindspore/nn/optim/adadelta.py +1 -0
  234. mindspore/nn/optim/adafactor.py +2 -1
  235. mindspore/nn/optim/adam.py +7 -4
  236. mindspore/nn/optim/adamax.py +3 -2
  237. mindspore/nn/optim/adasum.py +2 -2
  238. mindspore/nn/optim/asgd.py +2 -3
  239. mindspore/nn/optim/ftrl.py +6 -5
  240. mindspore/nn/optim/lamb.py +7 -4
  241. mindspore/nn/optim/lars.py +1 -1
  242. mindspore/nn/optim/lazyadam.py +5 -3
  243. mindspore/nn/optim/momentum.py +2 -1
  244. mindspore/nn/optim/optimizer.py +53 -4
  245. mindspore/nn/optim/proximal_ada_grad.py +3 -4
  246. mindspore/nn/optim/rmsprop.py +4 -3
  247. mindspore/nn/optim/rprop.py +23 -12
  248. mindspore/nn/optim/sgd.py +26 -11
  249. mindspore/nn/optim/thor.py +9 -7
  250. mindspore/nn/probability/bijector/bijector.py +5 -5
  251. mindspore/nn/probability/bijector/power_transform.py +27 -27
  252. mindspore/nn/probability/bijector/softplus.py +3 -3
  253. mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
  254. mindspore/nn/probability/distribution/bernoulli.py +5 -5
  255. mindspore/nn/probability/distribution/beta.py +3 -3
  256. mindspore/nn/probability/distribution/categorical.py +7 -7
  257. mindspore/nn/probability/distribution/cauchy.py +0 -1
  258. mindspore/nn/probability/distribution/distribution.py +3 -3
  259. mindspore/nn/probability/distribution/gamma.py +3 -3
  260. mindspore/nn/probability/distribution/geometric.py +4 -4
  261. mindspore/nn/probability/distribution/gumbel.py +4 -4
  262. mindspore/nn/probability/distribution/log_normal.py +2 -2
  263. mindspore/nn/probability/distribution/logistic.py +2 -2
  264. mindspore/nn/probability/distribution/poisson.py +4 -4
  265. mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
  266. mindspore/nn/probability/distribution/uniform.py +6 -6
  267. mindspore/nn/wrap/__init__.py +4 -2
  268. mindspore/nn/wrap/cell_wrapper.py +87 -34
  269. mindspore/nn/wrap/grad_reducer.py +8 -5
  270. mindspore/nn/wrap/loss_scale.py +105 -42
  271. mindspore/numpy/array_creations.py +1 -2
  272. mindspore/numpy/array_ops.py +3 -2
  273. mindspore/numpy/utils_const.py +5 -5
  274. mindspore/offline_debug/convert_async.py +2 -2
  275. mindspore/ops/_grad_experimental/__init__.py +0 -5
  276. mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
  277. mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
  278. mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
  279. mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
  280. mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
  281. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
  282. mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
  283. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  284. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
  285. mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
  286. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  287. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
  288. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  289. mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
  290. mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
  291. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  292. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
  293. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
  294. mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
  295. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
  296. mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
  297. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
  298. mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
  299. mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
  300. mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
  301. mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
  302. mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
  303. mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
  304. mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
  305. mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
  306. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
  307. mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
  308. mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
  309. mindspore/ops/_op_impl/tbe/__init__.py +4 -4
  310. mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
  311. mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
  312. mindspore/ops/_primitive_cache.py +1 -1
  313. mindspore/ops/_tracefunc.py +45 -13
  314. mindspore/ops/_utils/utils.py +6 -1
  315. mindspore/ops/_vmap/vmap_array_ops.py +3 -3
  316. mindspore/ops/_vmap/vmap_base.py +3 -3
  317. mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
  318. mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
  319. mindspore/ops/_vmap/vmap_math_ops.py +5 -2
  320. mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
  321. mindspore/ops/arg_dtype_cast.py +54 -0
  322. mindspore/ops/composite/base.py +37 -10
  323. mindspore/ops/composite/math_ops.py +5 -4
  324. mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
  325. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
  326. mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
  327. mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
  328. mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
  329. mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
  330. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
  331. mindspore/ops/deprecated.py +304 -0
  332. mindspore/ops/function/__init__.py +4 -1
  333. mindspore/ops/function/array_func.py +174 -193
  334. mindspore/ops/function/clip_func.py +81 -13
  335. mindspore/ops/function/debug_func.py +1 -1
  336. mindspore/ops/function/grad/grad_func.py +18 -9
  337. mindspore/ops/function/image_func.py +10 -4
  338. mindspore/ops/function/linalg_func.py +5 -5
  339. mindspore/ops/function/math_func.py +575 -386
  340. mindspore/ops/function/nn_func.py +568 -260
  341. mindspore/ops/function/random_func.py +88 -57
  342. mindspore/ops/function/sparse_func.py +1 -1
  343. mindspore/ops/function/sparse_unary_func.py +14 -12
  344. mindspore/ops/function/vmap_func.py +6 -5
  345. mindspore/ops/functional.py +15 -10
  346. mindspore/ops/op_info_register.py +244 -25
  347. mindspore/ops/operations/__init__.py +31 -19
  348. mindspore/ops/operations/_grad_ops.py +71 -7
  349. mindspore/ops/operations/_inner_ops.py +350 -17
  350. mindspore/ops/operations/_quant_ops.py +4 -8
  351. mindspore/ops/operations/_sequence_ops.py +42 -0
  352. mindspore/ops/operations/array_ops.py +68 -282
  353. mindspore/ops/operations/comm_ops.py +107 -59
  354. mindspore/ops/operations/custom_ops.py +94 -70
  355. mindspore/ops/operations/debug_ops.py +8 -4
  356. mindspore/ops/operations/image_ops.py +18 -12
  357. mindspore/ops/operations/inner_ops.py +26 -3
  358. mindspore/ops/operations/math_ops.py +192 -144
  359. mindspore/ops/operations/nn_ops.py +857 -489
  360. mindspore/ops/operations/other_ops.py +0 -22
  361. mindspore/ops/operations/random_ops.py +53 -111
  362. mindspore/ops/operations/sparse_ops.py +3 -1
  363. mindspore/ops/primitive.py +24 -18
  364. mindspore/parallel/_auto_parallel_context.py +68 -8
  365. mindspore/parallel/_cost_model_context.py +2 -2
  366. mindspore/parallel/_offload_context.py +17 -3
  367. mindspore/parallel/_parallel_serialization.py +12 -5
  368. mindspore/parallel/_ps_context.py +12 -0
  369. mindspore/parallel/_tensor.py +18 -13
  370. mindspore/parallel/_transformer/layers.py +5 -3
  371. mindspore/parallel/_transformer/loss.py +1 -0
  372. mindspore/parallel/_transformer/moe.py +2 -2
  373. mindspore/parallel/_transformer/op_parallel_config.py +12 -1
  374. mindspore/parallel/_transformer/transformer.py +23 -3
  375. mindspore/parallel/_utils.py +11 -7
  376. mindspore/parallel/algo_parameter_config.py +85 -5
  377. mindspore/parallel/checkpoint_transform.py +19 -12
  378. mindspore/parallel/shard.py +21 -14
  379. mindspore/profiler/common/struct_type.py +3 -3
  380. mindspore/profiler/common/util.py +4 -2
  381. mindspore/profiler/envprofiling.py +1 -1
  382. mindspore/profiler/parser/aicpu_data_parser.py +5 -3
  383. mindspore/profiler/parser/ascend_flops_generator.py +2 -2
  384. mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
  385. mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
  386. mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
  387. mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
  388. mindspore/profiler/parser/ascend_op_generator.py +6 -6
  389. mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
  390. mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
  391. mindspore/profiler/parser/base_timeline_generator.py +10 -8
  392. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
  393. mindspore/profiler/parser/flops_parser.py +15 -11
  394. mindspore/profiler/parser/framework_parser.py +38 -22
  395. mindspore/profiler/parser/hccl_parser.py +16 -12
  396. mindspore/profiler/parser/integrator.py +22 -11
  397. mindspore/profiler/parser/memory_usage_parser.py +2 -2
  398. mindspore/profiler/parser/minddata_analyzer.py +12 -14
  399. mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
  400. mindspore/profiler/parser/msadvisor_parser.py +8 -4
  401. mindspore/profiler/parser/op_intermediate_parser.py +5 -2
  402. mindspore/profiler/parser/optime_parser.py +1 -1
  403. mindspore/profiler/parser/profiler_info.py +21 -2
  404. mindspore/profiler/parser/step_trace_parser.py +11 -14
  405. mindspore/profiler/profiling.py +179 -89
  406. mindspore/rewrite/api/node.py +102 -19
  407. mindspore/rewrite/api/node_type.py +5 -1
  408. mindspore/rewrite/api/pattern_engine.py +1 -1
  409. mindspore/rewrite/api/scoped_value.py +9 -17
  410. mindspore/rewrite/api/symbol_tree.py +131 -47
  411. mindspore/rewrite/ast_helpers/__init__.py +2 -1
  412. mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
  413. mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
  414. mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
  415. mindspore/rewrite/common/rewrite_elog.py +5 -1
  416. mindspore/rewrite/namer.py +33 -24
  417. mindspore/rewrite/namespace.py +14 -5
  418. mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
  419. mindspore/rewrite/node/call_function.py +79 -0
  420. mindspore/rewrite/node/cell_container.py +135 -0
  421. mindspore/rewrite/node/control_flow.py +88 -0
  422. mindspore/rewrite/{node.py → node/node.py} +273 -234
  423. mindspore/rewrite/node/node_manager.py +254 -0
  424. mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
  425. mindspore/rewrite/parsers/arguments_parser.py +22 -21
  426. mindspore/rewrite/parsers/assign_parser.py +216 -221
  427. mindspore/rewrite/parsers/attribute_parser.py +9 -7
  428. mindspore/rewrite/parsers/class_def_parser.py +174 -113
  429. mindspore/rewrite/parsers/constant_parser.py +9 -6
  430. mindspore/rewrite/parsers/container_parser.py +9 -7
  431. mindspore/rewrite/parsers/for_parser.py +42 -21
  432. mindspore/rewrite/parsers/function_def_parser.py +24 -16
  433. mindspore/rewrite/parsers/if_parser.py +28 -24
  434. mindspore/rewrite/parsers/module_parser.py +196 -25
  435. mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
  436. mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
  437. mindspore/rewrite/parsers/return_parser.py +6 -6
  438. mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
  439. mindspore/rewrite/sparsify/utils.py +1 -1
  440. mindspore/rewrite/symbol_tree.py +523 -578
  441. mindspore/rewrite/symbol_tree_builder.py +9 -193
  442. mindspore/rewrite/symbol_tree_dumper.py +2 -2
  443. mindspore/run_check/_check_version.py +6 -4
  444. mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
  445. mindspore/safeguard/rewrite_obfuscation.py +541 -0
  446. mindspore/scipy/linalg.py +1 -1
  447. mindspore/scipy/ops.py +55 -5
  448. mindspore/scipy/optimize/__init__.py +3 -2
  449. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  450. mindspore/scipy/optimize/minimize.py +7 -3
  451. mindspore/train/_utils.py +7 -3
  452. mindspore/train/amp.py +323 -123
  453. mindspore/train/anf_ir_pb2.py +14 -2
  454. mindspore/train/callback/_backup_and_restore.py +2 -12
  455. mindspore/train/callback/_callback.py +29 -4
  456. mindspore/train/callback/_checkpoint.py +23 -8
  457. mindspore/train/callback/_early_stop.py +2 -2
  458. mindspore/train/callback/_landscape.py +4 -4
  459. mindspore/train/callback/_loss_monitor.py +2 -2
  460. mindspore/train/callback/_on_request_exit.py +2 -2
  461. mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
  462. mindspore/train/callback/_summary_collector.py +15 -8
  463. mindspore/train/callback/_time_monitor.py +58 -5
  464. mindspore/train/data_sink.py +5 -11
  465. mindspore/train/dataset_helper.py +84 -57
  466. mindspore/train/loss_scale_manager.py +2 -2
  467. mindspore/train/metrics/__init__.py +3 -3
  468. mindspore/train/metrics/cosine_similarity.py +1 -1
  469. mindspore/train/metrics/hausdorff_distance.py +3 -2
  470. mindspore/train/metrics/mean_surface_distance.py +3 -2
  471. mindspore/train/metrics/metric.py +39 -19
  472. mindspore/train/metrics/roc.py +2 -2
  473. mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
  474. mindspore/train/mind_ir_pb2.py +85 -36
  475. mindspore/train/model.py +187 -47
  476. mindspore/train/serialization.py +487 -161
  477. mindspore/train/summary/_summary_adapter.py +1 -1
  478. mindspore/train/summary/_writer_pool.py +3 -2
  479. mindspore/train/summary/summary_record.py +37 -17
  480. mindspore/train/train_thor/convert_utils.py +3 -3
  481. mindspore/train/train_thor/dataset_helper.py +1 -1
  482. mindspore/version.py +1 -1
  483. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
  484. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
  485. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
  486. mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
  487. mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
  488. mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
  489. mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
  490. mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
  491. mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
  492. mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
  493. mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
  494. mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
  495. mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
  496. mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
  497. mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
  498. mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
  499. mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
  500. mindspore/_akg/akg/tvm/rpc/base.py +0 -182
  501. mindspore/_akg/akg/tvm/rpc/client.py +0 -436
  502. mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
  503. mindspore/_akg/akg/tvm/rpc/server.py +0 -413
  504. mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
  505. mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
  506. mindspore/_extends/graph_kernel/expander.py +0 -80
  507. mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
  508. mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
  509. mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
  510. mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
  511. mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
  512. mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
  513. mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
  514. mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
  515. mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
  516. mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
  517. mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
  518. mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
  519. mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
  520. mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
  521. mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
  522. mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
  523. mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
  524. mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
  525. mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
  526. mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
  527. mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
  528. mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
  529. mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
  530. mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
  531. mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
  532. mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
  533. mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
  534. mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
  535. mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
  536. mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
  537. mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
  538. mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
  539. mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
  540. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
  541. mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
  542. mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
  543. mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
  544. mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
  545. mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
  546. mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
  547. mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
  548. mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
  549. mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
  550. mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
  551. mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
  552. mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
  553. mindspore/dataset/datapreprocess/__init__.py +0 -20
  554. mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
  555. mindspore/include/api/net.h +0 -142
  556. mindspore/nn/lr_scheduler.py +0 -262
  557. mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
  558. mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
  559. mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
  560. mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
  561. mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
  562. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
  563. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
  564. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
  565. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
  566. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
  567. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  568. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  569. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  570. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  571. mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
  572. mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
  573. mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
  574. mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
  575. mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
  576. mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
  577. mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
  578. mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
  579. mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
  580. mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
  581. mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
  582. mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
  583. mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
  584. mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
  585. mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
  586. mindspore/rewrite/node_visitor.py +0 -44
  587. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  588. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  589. {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,541 @@
1
+ # Copyright 2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """obfuscate network based on rewrite interfaces."""
16
+ import os
17
+ import re
18
+ import secrets
19
+ from pathlib import Path
20
+
21
+ from mindspore import ops, nn
22
+ from mindspore.common.tensor import Tensor
23
+ from mindspore import log as logger
24
+ from mindspore import load_checkpoint, save_checkpoint
25
+ from mindspore.rewrite import SymbolTree, Node, NodeType, TreeNodeHelper, ScopedValue
26
+ from mindspore.rewrite.parsers.class_def_parser import ClassDefParser
27
+ from mindspore.rewrite.parsers.class_def_parser import ModuleParser
28
+
29
+ OBF_RATIOS_LENGTH = 1
30
+ MAX_OBF_RATIOS_NUM = 50
31
+ OBF_RATIOS_WIDTH = 0
32
+ OBF_RATIOS_INSERT_INDEX = 0
33
+
34
+
35
+ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
36
+ """
37
+ obfuscate the plaintext checkpoint files. Usually used in conjunction with
38
+ :func:`mindspore.load_obf_params_into_net`.
39
+ interface.
40
+
41
+ Args:
42
+ network (nn.Cell): The original network that need to be obfuscated.
43
+ ckpt_files (str): The directory path of original ckpt files.
44
+ target_modules (list[str]): The target module of network that need to be obfuscated. The first string
45
+ represents the network path of target module in original network, which should be in form of ``'A/B/C'``.
46
+ The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For
47
+ example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
48
+ If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
49
+ 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers
50
+ (such as transformer layers or resnet blocks). If target_modules is ``None``, the function would search
51
+ target modules by itself. If found, the searched target module would be used, otherwise suggested target
52
+ modules would be given with warning log. Default: ``None``.
53
+ saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
54
+ obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in
55
+ range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.
56
+
57
+ Raises:
58
+ TypeError: If `network` is not nn.Cell.
59
+ TypeError: If `ckpt_files` is not string or `saved_path` is not string.
60
+ TypeError: If `target_modules` is not list.
61
+ TypeError: If target_modules's elements are not string.
62
+ ValueError: If `ckpt_files` is not exist or `saved_path` is not exist.
63
+ ValueError: If the number of elements of `target_modules` is less than ``2``.
64
+ ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
65
+ letters, numbers, ``'_'`` and ``'/'``.
66
+ ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and
67
+ lowercase letters, numbers, ``'_'`` and ``'|'``.
68
+ ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
69
+ 'obfuscate_layers:int'.
70
+
71
+ Returns:
72
+ list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network.
73
+
74
+ Examples:
75
+ >>> from mindspore import obfuscate_ckpt, save_checkpoint
76
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
77
+ >>> net = LeNet5()
78
+ >>> save_checkpoint(net, './test_net.ckpt')
79
+ >>> target_modules = ['', 'fc1|fc2']
80
+ >>> obfuscate_ckpt(net, target_modules, './', './')
81
+ """
82
+ if not isinstance(network, nn.Cell):
83
+ raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
84
+ _check_dir_path('ckpt_files', ckpt_files)
85
+ _check_dir_path('saved_path', saved_path)
86
+ # Try to find default target modules
87
+ if target_modules is None:
88
+ to_split_modules = _get_default_target_modules(ckpt_files)
89
+ else:
90
+ if len(target_modules) >= 1 and target_modules[0] == '/':
91
+ target_modules[0] = ''
92
+ to_split_modules = target_modules
93
+ if not _check_valid_target(network, to_split_modules):
94
+ raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'."
95
+ .format(to_split_modules))
96
+ if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
97
+ raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
98
+ .format(obfuscate_scale))
99
+ # generate and save obf_ratios to saved_path
100
+ path_list = to_split_modules[0].split('/')
101
+ target_list = to_split_modules[1].split('|')
102
+ global OBF_RATIOS_LENGTH
103
+ number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
104
+ if number_of_ratios > MAX_OBF_RATIOS_NUM:
105
+ OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
106
+ number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
107
+ obf_ratios = []
108
+ secrets_generator = secrets.SystemRandom()
109
+ for _ in range(number_of_ratios):
110
+ secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale)
111
+ obf_ratios.append(secure_float)
112
+ # start obfuscate ckpt
113
+ ckpt_dir_files = os.listdir(ckpt_files)
114
+ for ckpt_name in ckpt_dir_files:
115
+ sub_path = os.path.abspath(ckpt_files) + '/' + ckpt_name
116
+ if Path(sub_path).is_dir():
117
+ sub_ckpt_file_list = os.listdir(sub_path)
118
+ new_saved_path = os.path.abspath(saved_path) + '/' + ckpt_name
119
+ if not os.path.exists(new_saved_path):
120
+ try:
121
+ os.mkdir(new_saved_path, mode=0o700)
122
+ except FileExistsError:
123
+ pass
124
+ for sub_ckpt_name in sub_ckpt_file_list:
125
+ if not sub_ckpt_name.endswith('.ckpt'):
126
+ continue
127
+ _obfuscate_single_ckpt(os.path.abspath(sub_path) + '/' + sub_ckpt_name, obf_ratios, path_list,
128
+ target_list, new_saved_path)
129
+ else:
130
+ if not ckpt_name.endswith('.ckpt'):
131
+ continue
132
+ _obfuscate_single_ckpt(os.path.abspath(ckpt_files) + '/' + ckpt_name, obf_ratios, path_list,
133
+ target_list, saved_path)
134
+ return obf_ratios
135
+
136
+
137
+ def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path):
138
+ """Obfuscate single ckpt file"""
139
+ module_has_been_obfuscated = set()
140
+ try:
141
+ ckpt_param = load_checkpoint(ckpt_name)
142
+ except (ValueError, TypeError, OSError):
143
+ logger.error("Load checkpoint failed for file {}.".format(ckpt_name))
144
+ return None
145
+ obf_ratios_index = -1
146
+ for item in ckpt_param:
147
+ module = _get_valid_module(item, path_list, target_list)
148
+ if module:
149
+ layer_index = _judge_layer_index(item)
150
+ if layer_index >= OBF_RATIOS_LENGTH:
151
+ continue
152
+ if module not in module_has_been_obfuscated:
153
+ module_has_been_obfuscated.add(module)
154
+ obf_ratios_index += 1
155
+ ratio_total_index = layer_index * OBF_RATIOS_WIDTH + obf_ratios_index % OBF_RATIOS_WIDTH
156
+ ckpt_param[item].set_data(ckpt_param[item].value() / obf_ratios[ratio_total_index])
157
+ # save the obfuscated model to saved_path
158
+ obf_param_list = []
159
+ for item in ckpt_param:
160
+ obf_param_list.append({'name': item, 'data': ckpt_param[item]})
161
+ ckpt_file_name = ckpt_name.split('/')[-1]
162
+ obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt'
163
+ save_checkpoint(obf_param_list, os.path.abspath(saved_path) + '/' + obf_ckpt_file_name)
164
+ return None
165
+
166
+
167
+ def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs):
168
+ """
169
+ load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt`
170
+ interface.
171
+
172
+ Args:
173
+ network (nn.Cell): The original network that need to be obfuscated.
174
+ target_modules (list[str]): The target module of network that need to be obfuscated. The first string
175
+ represents the network path of target module in original network, which should be in form of ``'A/B/C'``.
176
+ The second string represents the obfuscation target module, which should be in form of ``'D|E|F'``. For
177
+ example, thr target_modules of GPT2 can be ``['backbone/blocks/attention', 'dense1|dense2|dense3']``.
178
+ If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
179
+ 'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers
180
+ (such as transformer layers or resnet blocks).
181
+ data_parallel_num (int): The data parallel number of parallel training. Default: 1.
182
+ obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`.
183
+ kwargs (dict): Configuration options dictionary.
184
+
185
+ - ignored_func_decorators (list[str]): The name list of function decorators in network's python code.
186
+ - ignored_class_decorators (list[str]): The name list of class decorators in network's python code.
187
+
188
+ Raises:
189
+ TypeError: If `network` is not nn.Cell.
190
+ TypeError: If `obf_ratios` is not Tensor.
191
+ TypeError: If `target_modules` is not list.
192
+ TypeError: If target_modules's elements are not string.
193
+ ValueError: If the number of elements of `target_modules` is less than ``2``.
194
+ ValueError: If `obf_ratios` is empty Tensor.
195
+ ValueError: If the first string of `target_modules` contains characters other than uppercase and lowercase
196
+ letters, numbers, ``'_'`` and ``'/'``.
197
+ ValueError: If the second string of `target_modules` is empty or contains characters other than uppercase and
198
+ lowercase letters, numbers, ``'_'`` and ``'|'``.
199
+ ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
200
+ 'obfuscate_layers:int'.
201
+ TypeError: If `ignored_func_decorators` is not list[str] or `ignored_class_decorators` is not list[str].
202
+
203
+ Examples:
204
+ >>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor
205
+ >>> import mindspore.common.dtype as mstype
206
+ >>> import numpy as np
207
+ >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
208
+ >>> net = LeNet5()
209
+ >>> save_checkpoint(net, './test_net.ckpt')
210
+ >>> target_modules = ['', 'fc1|fc2']
211
+ >>> # obfuscate ckpt files
212
+ >>> obfuscate_ckpt(net, target_modules, './', './')
213
+ >>> # load obf ckpt into network
214
+ >>> new_net = LeNet5()
215
+ >>> load_checkpoint('./test_net_obf.ckpt', new_net)
216
+ >>> obf_ratios = Tensor(np.load('./obf_ratios.npy'), mstype.float16)
217
+ >>> obf_net = load_obf_params_into_net(new_net, target_modules, obf_ratios)
218
+ """
219
+ if not isinstance(network, nn.Cell):
220
+ raise TypeError("network must be nn.Cell, but got {}.".format(type(network)))
221
+ if not isinstance(obf_ratios, Tensor):
222
+ raise TypeError("obf_ratios must be MindSpore Tensor, but got {}.".format(type(obf_ratios)))
223
+ if obf_ratios.size == 0:
224
+ raise ValueError("obf_ratios can not be empty.")
225
+ if not _check_valid_target(network, target_modules):
226
+ raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
227
+ if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
228
+ raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
229
+ if len(target_modules) >= 1 and target_modules[0] == '/':
230
+ target_modules[0] = ''
231
+ path_list = target_modules[0].split('/')
232
+ path_len = len(path_list)
233
+ target_list = []
234
+ for _ in range(path_len):
235
+ target_list.append([])
236
+ target_list.append(target_modules[1].split('|'))
237
+ global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
238
+ number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
239
+ if number_of_ratios > MAX_OBF_RATIOS_NUM:
240
+ OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
241
+ number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
242
+ MAX_OBF_RATIOS_NUM = number_of_ratios
243
+ rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
244
+ setattr(rewrite_network, 'obf_ratios', obf_ratios)
245
+ return rewrite_network
246
+
247
+
248
+ def _check_dir_path(name, dir_path):
249
+ """check directory path"""
250
+ if not isinstance(dir_path, str):
251
+ raise TypeError("{} must be string, but got {}.".format(name, type(dir_path)))
252
+ if not os.path.exists(dir_path):
253
+ raise ValueError("{} is not exist, please check the input {}.".format(dir_path, name))
254
+ if not Path(dir_path).is_dir():
255
+ raise TypeError("{} must be a directory path, but got {}.".format(name, dir_path))
256
+
257
+
258
+ def _judge_layer_index(layer_name):
259
+ """Judge the layer index of target layers"""
260
+ split_name = layer_name.split('.')
261
+ for split_str in split_name[:]:
262
+ if split_str.isdigit():
263
+ return int(split_str)
264
+ return 0
265
+
266
+
267
+ def _check_valid_target(network, target_modules):
268
+ """check whether the input 'target_modules' exists"""
269
+ if not isinstance(target_modules, list):
270
+ raise TypeError("target_modules type should be list, but got {}.".format(type(target_modules)))
271
+ if len(target_modules) < 2:
272
+ raise ValueError("target_modules should contain at least two string values, in the form of ['A/B/C', 'D1|D2'],"
273
+ "but got {}.".format(target_modules))
274
+ if (not isinstance(target_modules[0], str)) or (not isinstance(target_modules[1], str)):
275
+ raise TypeError("The values of target_modules should be string, but got {} and {}.".
276
+ format(type(target_modules[0]), type(target_modules[1])))
277
+
278
+ if not target_modules[1]:
279
+ raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'"
280
+ .format(target_modules[1]))
281
+ if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0]) \
282
+ or not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\|*_*)*', string=target_modules[1]):
283
+ raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']."
284
+ "target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/',"
285
+ "target_modules[1] can only contain uppercase and lowercase letters, numbers, '_' and '|'"
286
+ .format(target_modules))
287
+ # target_modules[0] is allowed to be '', it means the main network path
288
+ path_list = target_modules[0].split('/')
289
+ target_list = target_modules[1].split('|')
290
+ net = network
291
+ # DFS check whether path_list is valid
292
+ stk = [net]
293
+ i = 0
294
+ global OBF_RATIOS_LENGTH
295
+ OBF_RATIOS_LENGTH = 1
296
+ while stk and i < len(path_list):
297
+ net = stk.pop()
298
+ if hasattr(net, path_list[i]):
299
+ net = getattr(net, path_list[i])
300
+ i += 1
301
+ if isinstance(net, nn.CellList):
302
+ OBF_RATIOS_LENGTH *= len(net)
303
+ for n in net:
304
+ stk.append(n)
305
+ elif isinstance(net, nn.Cell):
306
+ stk.append(net)
307
+ else:
308
+ raise TypeError("Target_modules[0] should be a subgraph and it's type should be nn.Cell(nn.CellList),"
309
+ "but got type {}".format(type(net)))
310
+ if target_modules[0] != '' and i != len(path_list):
311
+ raise ValueError("the path {} does not exist.".format(target_modules[0]))
312
+ # check whether target_list is valid
313
+ global OBF_RATIOS_WIDTH
314
+ OBF_RATIOS_WIDTH = 0
315
+ for target in target_list:
316
+ if not hasattr(net, target):
317
+ logger.warning("{} does not exist in the path {}".format(target, target_modules[0]))
318
+ else:
319
+ OBF_RATIOS_WIDTH += 1
320
+ if OBF_RATIOS_WIDTH == 0:
321
+ raise ValueError("all targets {} do not exist in the path {}.".format(target_list, target_modules[0]))
322
+ _update_max_obf_ratios_num(target_modules)
323
+ return True
324
+
325
+
326
+ def _update_max_obf_ratios_num(target_modules):
327
+ """Update MAX_OBF_RATIOS_NUM"""
328
+ if len(target_modules) >= 3:
329
+ obfuscate_layers = target_modules[2].split(':')
330
+ if len(obfuscate_layers) != 2 or obfuscate_layers[0] != 'obfuscate_layers':
331
+ raise ValueError("The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
332
+ "'obfuscate_layers:int'")
333
+ global MAX_OBF_RATIOS_NUM
334
+ if obfuscate_layers[1] == 'all':
335
+ MAX_OBF_RATIOS_NUM = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
336
+ else:
337
+ if not obfuscate_layers[1].isdigit():
338
+ raise ValueError(
339
+ "The third value of target_modules should be in the format of 'obfuscate_layers:all' or"
340
+ "'obfuscate_layers:int'")
341
+ MAX_OBF_RATIOS_NUM = int(obfuscate_layers[1]) * OBF_RATIOS_WIDTH
342
+
343
+
344
+ def _get_default_target_modules(ckpt_files):
345
+ """Get the default or suggested target modules, if the target modules is None."""
346
+
347
+ def _split_to_path_and_target(module, target):
348
+ # split module into path list and target list
349
+ target_index = module.index(target)
350
+ path = module[:target_index - 1]
351
+ target = module[target_index:].split('/')[0]
352
+ return path, target
353
+
354
+ def _find_default_obfuscate_modules(net_path):
355
+ # find modules including the default paths
356
+ default_module = {'attention'}
357
+ for module in default_module:
358
+ if module in net_path and module not in candidate_modules:
359
+ candidate_modules.append(net_path)
360
+ # find the default targets in the default module
361
+ default_target = {'dense', 'query', 'key', 'value'}
362
+ for target in default_target:
363
+ for candidate in candidate_modules:
364
+ if target in candidate:
365
+ path, target = _split_to_path_and_target(candidate, target)
366
+ if path not in paths:
367
+ paths.append(path)
368
+ if target not in targets:
369
+ targets.append(target)
370
+
371
+ def _find_suggested_obfuscate_modules(net_path):
372
+ default_target = {'dense', 'query', 'key', 'value'}
373
+ for target in default_target:
374
+ # find the suggest modules
375
+ if target in net_path:
376
+ path, target = _split_to_path_and_target(net_path, target)
377
+ if [path, target] not in suggest_modules:
378
+ suggest_modules.append([path, target])
379
+
380
+ # store the potential candidate_modules
381
+ candidate_modules = []
382
+ suggest_modules = []
383
+ paths = []
384
+ targets = []
385
+ ckpt_dir_files = os.listdir(ckpt_files)
386
+ for ckpt_name in ckpt_dir_files:
387
+ if not ckpt_name.endswith('.ckpt'):
388
+ continue
389
+ try:
390
+ ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name)
391
+ except (ValueError, TypeError, OSError):
392
+ logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name))
393
+ return None
394
+ for item in ckpt_param:
395
+ param_path = _remove_digit(item)
396
+ param_path = '/'.join(param_path)
397
+ # find candidate modules including the default paths and append candidate_modules
398
+ _find_default_obfuscate_modules(param_path)
399
+ # give the suggested modules and find the default targets in the default module
400
+ _find_suggested_obfuscate_modules(param_path)
401
+ if paths and targets:
402
+ target_modules = [paths[0], '|'.join(targets)]
403
+ logger.warning("The default obfuscate modules is obtained:{}".format(target_modules))
404
+ return target_modules
405
+ # logging the suggested target module
406
+ logger.warning("The default obfuscate modules can not be obtained. The suggested possible paths are given below: {}"
407
+ .format(suggest_modules))
408
+ raise ValueError("Can not get the default path, please specify the path in the form of ['A/B/C', 'D1|D2']")
409
+
410
+
411
+ def _get_valid_module(item, path_list, target_list):
412
+ """get the valid module"""
413
+ number_path = len(path_list)
414
+ net_path = _remove_digit(item)
415
+ net_path = '/'.join(net_path[:number_path])
416
+ tar_path = '/'.join(path_list)
417
+ # update the weights with obf_ratios in target module
418
+ if net_path == tar_path:
419
+ for target in target_list:
420
+ if target in item.split('.'):
421
+ target_index = item.split('.').index(target)
422
+ module = ''.join(item.split('.')[:target_index + 1])
423
+ return module
424
+ return None
425
+
426
+
427
+ def _remove_digit(item):
428
+ """remove digit in the parameter path"""
429
+ param_path = item.split('.')
430
+ for tmp_str in param_path[:]:
431
+ if tmp_str.isdigit():
432
+ param_path.remove(tmp_str)
433
+ return param_path
434
+
435
+
436
+ def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
437
+ """obfuscate original network, including add mul operation and add inputs for passing obf_ratio."""
438
+
439
+ def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'):
440
+ """add inputs for passing obf_ratio"""
441
+ last_input = None
442
+ for node in stree.nodes():
443
+ if node.get_node_type() == NodeType.Input:
444
+ last_input = node
445
+ position = stree.after(last_input)
446
+ # the insert input node name would be 'input_y_obf'
447
+ new_input_node = last_input.create_input(arg_name)
448
+ stree.insert(position, new_input_node)
449
+
450
+ def _insert_mul(stree: SymbolTree, node: Node, index: int):
451
+ """add mul operation for original network"""
452
+ arg_list = node.get_targets().copy()
453
+ input_y_node = stree.get_node("input_y_obf")
454
+ v: str = input_y_node.get_targets()[0].value
455
+ sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]')
456
+ arg_list.append(sv)
457
+ target_list = node.get_targets().copy()
458
+ if data_parallel_num > 1:
459
+ logger.info("Data parallel number is: {}".format(data_parallel_num))
460
+ new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
461
+ targets=target_list, args=arg_list, name='mul')
462
+ else:
463
+ new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul')
464
+ position = stree.after(node)
465
+ stree.insert(position, new_mul_node)
466
+
467
+ def _insert_mul_by_name(stree: SymbolTree, after_name_list: list):
468
+ """add mul operation after the target nodes according the name of them"""
469
+ if not after_name_list:
470
+ return
471
+ for node in stree.nodes():
472
+ for after_name in after_name_list:
473
+ if node.get_name() == after_name:
474
+ global OBF_RATIOS_INSERT_INDEX
475
+ if OBF_RATIOS_INSERT_INDEX < MAX_OBF_RATIOS_NUM:
476
+ _insert_mul(stree, node, OBF_RATIOS_INSERT_INDEX)
477
+ OBF_RATIOS_INSERT_INDEX += 1
478
+
479
+ def _update_subnet(stree: SymbolTree, substree: SymbolTree, subnode: Node):
480
+ """update the network once the subnet is obfuscated"""
481
+ new_net = substree.get_network()
482
+ input_y_node = substree.get_node("input_y_obf")
483
+ if input_y_node is None:
484
+ return
485
+ arg_list = subnode.get_args().copy()
486
+ kwargs_list = list(subnode.get_kwargs().values())
487
+ arg_list.extend(kwargs_list)
488
+ v: str = input_y_node.get_targets()[0].value
489
+ arg_obf: ScopedValue = ScopedValue.create_naming_value("y_obf=" + v)
490
+ arg_list.append(arg_obf)
491
+ target_list = subnode.get_targets().copy()
492
+ name = subnode.get_name()
493
+ new_node = subnode.create_call_cell(cell=new_net, targets=target_list, args=arg_list, name=name)
494
+ stree.replace(subnode, [new_node])
495
+
496
+ def _traverse(stree, i=0):
497
+ """traverse and obfuscate the original network"""
498
+ if len(path_list) == i:
499
+ return
500
+ for node in stree.nodes():
501
+ node_name = node.get_name()
502
+ if node.get_node_type() == NodeType.Tree and node_name.startswith(path_list[i]):
503
+ sub_stree = TreeNodeHelper.get_sub_tree(node)
504
+ _traverse(sub_stree, i + 1)
505
+ _insert_input(sub_stree, arg_name='y_obf')
506
+ _insert_mul_by_name(sub_stree, after_name_list=target_list[i + 1])
507
+ _update_subnet(stree, sub_stree, node)
508
+
509
+ def _register_denied_func_decorators(fn):
510
+ """set the function decorators which should be denied for parse"""
511
+ name = "denied_function_decorator_list"
512
+ setattr(ClassDefParser, name, fn)
513
+
514
+ def _register_denied_class_decorators(fn):
515
+ """set the class decorators which should be denied for parse"""
516
+ name = "denied_class_decorator_list"
517
+ setattr(ModuleParser, name, fn)
518
+
519
+ if 'ignored_func_decorators' in kwargs.keys():
520
+ kw_func_dec = kwargs["ignored_func_decorators"]
521
+ if not isinstance(kw_func_dec, list):
522
+ raise TypeError('{} should be list, but got {}'.format(kw_func_dec, type(kw_func_dec)))
523
+ if kw_func_dec and not isinstance(kw_func_dec[0], str):
524
+ raise TypeError('elements of {} should be str, but got {}'.format(kw_func_dec, type(kw_func_dec[0])))
525
+ _register_denied_func_decorators(kw_func_dec)
526
+ else:
527
+ _register_denied_func_decorators(["_args_type_validator_check", "_LogActionOnce", "cell_attr_register"])
528
+ if 'ignored_class_decorators' in kwargs.keys():
529
+ kw_class_dec = kwargs["ignored_class_decorators"]
530
+ _register_denied_class_decorators(kw_class_dec)
531
+ if not isinstance(kw_class_dec, list):
532
+ raise TypeError('{} should be list[str] type, but got {}'.format(kw_class_dec, type(kw_class_dec)))
533
+ if kw_class_dec and not isinstance(kw_class_dec[0], str):
534
+ raise TypeError('elements of {} should be str, but got {}'.format(kw_class_dec, type(kw_class_dec[0])))
535
+
536
+ main_stree = SymbolTree.create(model)
537
+ _traverse(main_stree, 0)
538
+ _insert_input(main_stree, arg_name='y_obf')
539
+ _insert_mul_by_name(main_stree, after_name_list=target_list[0])
540
+ new_net = main_stree.get_network()
541
+ return new_net
mindspore/scipy/linalg.py CHANGED
@@ -461,8 +461,8 @@ def lu_pivots_to_permutation(pivots, permutation_size: int):
461
461
  loc = mnp.ix_(*(mnp.arange(0, b) for b in batch_dims))
462
462
  x = permutation[..., i]
463
463
  y = permutation[loc + (j,)]
464
- permutation[..., i] = y
465
464
  permutation[loc + (j,)] = x
465
+ permutation[..., i] = y
466
466
  return permutation
467
467
 
468
468
 
mindspore/scipy/ops.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -156,14 +156,64 @@ class LU(PrimitiveWithInfer):
156
156
 
157
157
 
158
158
  class LinearSumAssignment(Primitive):
159
- """Solve the linear sum assignment problem."""
159
+ r"""
160
+ Solve the linear sum assignment problem.
161
+
162
+ The assignment problem is represented as follows:
163
+
164
+ .. math::
165
+ min\sum_{i}^{} \sum_{j}^{} C_{i,j} X_{i,j}
166
+
167
+ where :math:`C` is cost matrix, :math:`X_{i,j} = 1` means column :math:`j` is assigned to row :math:`i` .
168
+
169
+ Inputs:
170
+ - **cost_matrix** (Tensor) - 2-D cost matrix. Tensor of shape :math:`(M, N)` .
171
+ - **dimension_limit** (Tensor, optional) - A scalar used to limit the actual size of the 2nd dimension of
172
+ ``cost_matrix``. Default is ``Tensor(sys.maxsize)``, which means no limitation. The type is 0-D int64
173
+ Tensor.
174
+ - **maximize** (bool) - Calculate a maximum weight matching if true, otherwise calculate a minimum weight
175
+ matching.
176
+
177
+ Outputs:
178
+ A tuple of tensors containing 'row_idx' and 'col_idx'.
179
+
180
+ - **row_idx** (Tensor) - Row indices of the problem. If `dimension_limit` is given, -1 would be padded at the
181
+ end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
182
+ - **col_idx** (Tensor) - Column indices of the problem. If `dimension_limit` is given, -1 would be padded at
183
+ the end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
184
+
185
+ Raises:
186
+ TypeError: If the data type of `cost_matrix` is not the type in [float16, float32, float64,
187
+ int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool]
188
+ TypeError: If the type of `maximize` is not bool.
189
+ TypeError: If the data type of `dimension_limit` is not int64.
190
+ ValueError: If the rank of `cost_matrix` is not 2.
191
+ ValueError: If the number of input args is not 3.
192
+
193
+
194
+ Supported Platforms:
195
+ ``Ascend`` ``CPU``
196
+
197
+ Examples:
198
+ >>> import mindspore as ms
199
+ >>> import numpy as np
200
+ >>> from mindspore import Tensor
201
+ >>> from mindspore.scipy.ops import LinearSumAssignment
202
+ >>> lsap = LinearSumAssignment()
203
+ >>> cost_matrix = Tensor(np.array([[2, 3, 3], [3, 2, 3], [3, 3, 2]])).astype(ms.float64)
204
+ >>> dimension_limit = Tensor(2)
205
+ >>> maximize = False
206
+ >>> a, b = lsap(cost_matrix, dimension_limit, maximize)
207
+ >>> print(a)
208
+ [0 1 -1]
209
+ >>> print(b)
210
+ [0 1 -1]
211
+ """
160
212
 
161
213
  @prim_attr_register
162
214
  def __init__(self):
163
- super().__init__("LinearSumAssignment")
215
+ super().__init__(name="LinearSumAssignment")
164
216
  self.init_prim_io_names(inputs=['cost_matrix', 'dimension_limit', 'maximize'], outputs=['row_ind', 'col_ind'])
165
- self.add_prim_attr("cust_aicpu", "mindspore_aicpu_kernels")
166
-
167
217
 
168
218
  # pylint: disable=C0413,W0611
169
219
  from .ops_grad import get_bprpo_eigh, get_bprpo_trsm
@@ -1,4 +1,4 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -15,5 +15,6 @@
15
15
  """Optimize submodule"""
16
16
  from .minimize import minimize
17
17
  from .line_search import line_search
18
+ from .linear_sum_assignment import linear_sum_assignment
18
19
 
19
- __all__ = ["minimize", "line_search"]
20
+ __all__ = ["minimize", "line_search", "linear_sum_assignment"]