mindspore 2.2.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. mindspore/.commit_id +1 -1
  2. mindspore/_akg/akg/composite/build_module.py +104 -20
  3. mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
  4. mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
  5. mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
  6. mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
  7. mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
  8. mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
  9. mindspore/_akg/akg/utils/composite_op_helper.py +7 -2
  10. mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
  11. mindspore/_akg/akg/utils/kernel_exec.py +41 -15
  12. mindspore/_akg/akg/utils/tbe_codegen_utils.py +27 -6
  13. mindspore/_akg/akg/utils/util.py +56 -1
  14. mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
  15. mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
  16. mindspore/_checkparam.py +3 -3
  17. mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
  18. mindspore/_extends/graph_kernel/splitter.py +3 -2
  19. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +83 -66
  20. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -4
  21. mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
  22. mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +2 -1
  23. mindspore/_extends/parse/__init__.py +3 -2
  24. mindspore/_extends/parse/parser.py +6 -1
  25. mindspore/_extends/parse/standard_method.py +14 -11
  26. mindspore/_extends/remote/kernel_build_server.py +2 -1
  27. mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
  28. mindspore/bin/cache_admin +0 -0
  29. mindspore/bin/cache_server +0 -0
  30. mindspore/common/_utils.py +16 -0
  31. mindspore/common/api.py +1 -1
  32. mindspore/common/auto_dynamic_shape.py +81 -85
  33. mindspore/common/dump.py +1 -1
  34. mindspore/common/tensor.py +3 -20
  35. mindspore/config/op_info.config +1 -1
  36. mindspore/context.py +11 -4
  37. mindspore/dataset/engine/cache_client.py +8 -5
  38. mindspore/dataset/engine/datasets_standard_format.py +5 -0
  39. mindspore/dataset/vision/transforms.py +21 -21
  40. mindspore/experimental/optim/adam.py +1 -1
  41. mindspore/gen_ops.py +1 -1
  42. mindspore/include/api/model.h +17 -0
  43. mindspore/include/api/status.h +8 -3
  44. mindspore/lib/libdnnl.so.2 +0 -0
  45. mindspore/lib/libmindspore.so +0 -0
  46. mindspore/lib/libmindspore_backend.so +0 -0
  47. mindspore/lib/libmindspore_common.so +0 -0
  48. mindspore/lib/libmindspore_core.so +0 -0
  49. mindspore/lib/libmindspore_glog.so.0 +0 -0
  50. mindspore/lib/libmindspore_gpr.so.15 +0 -0
  51. mindspore/lib/libmindspore_grpc++.so.1 +0 -0
  52. mindspore/lib/libmindspore_grpc.so.15 +0 -0
  53. mindspore/lib/libmindspore_shared_lib.so +0 -0
  54. mindspore/lib/libnnacl.so +0 -0
  55. mindspore/lib/libopencv_core.so.4.5 +0 -0
  56. mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
  57. mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
  58. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
  59. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
  60. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
  61. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
  62. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  63. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  64. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  65. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  66. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
  67. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
  68. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
  69. mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
  70. mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
  71. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
  72. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
  73. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +78 -80
  74. mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
  75. mindspore/lib/plugin/ascend/libakg.so +0 -0
  76. mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
  77. mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
  78. mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
  79. mindspore/lib/plugin/cpu/libakg.so +0 -0
  80. mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
  81. mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
  82. mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
  83. mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
  84. mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
  85. mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
  86. mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
  87. mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
  88. mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
  89. mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
  90. mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
  91. mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
  92. mindspore/nn/cell.py +0 -3
  93. mindspore/nn/layer/activation.py +4 -5
  94. mindspore/nn/layer/conv.py +39 -23
  95. mindspore/nn/layer/flash_attention.py +54 -129
  96. mindspore/nn/layer/math.py +3 -7
  97. mindspore/nn/layer/rnn_cells.py +5 -5
  98. mindspore/nn/wrap/__init__.py +4 -2
  99. mindspore/nn/wrap/cell_wrapper.py +12 -3
  100. mindspore/numpy/utils_const.py +5 -5
  101. mindspore/ops/_grad_experimental/grad_array_ops.py +1 -1
  102. mindspore/ops/_grad_experimental/grad_implementations.py +2 -2
  103. mindspore/ops/_grad_experimental/grad_math_ops.py +19 -18
  104. mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
  105. mindspore/ops/_op_impl/aicpu/add.py +3 -3
  106. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
  107. mindspore/ops/_utils/utils.py +2 -0
  108. mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
  109. mindspore/ops/composite/multitype_ops/getitem_impl.py +2 -2
  110. mindspore/ops/function/array_func.py +10 -7
  111. mindspore/ops/function/grad/grad_func.py +0 -1
  112. mindspore/ops/function/nn_func.py +98 -9
  113. mindspore/ops/function/random_func.py +2 -1
  114. mindspore/ops/op_info_register.py +24 -21
  115. mindspore/ops/operations/__init__.py +6 -2
  116. mindspore/ops/operations/_grad_ops.py +25 -6
  117. mindspore/ops/operations/_inner_ops.py +155 -23
  118. mindspore/ops/operations/array_ops.py +9 -7
  119. mindspore/ops/operations/comm_ops.py +2 -2
  120. mindspore/ops/operations/custom_ops.py +85 -68
  121. mindspore/ops/operations/inner_ops.py +26 -3
  122. mindspore/ops/operations/math_ops.py +7 -6
  123. mindspore/ops/operations/nn_ops.py +193 -49
  124. mindspore/parallel/_parallel_serialization.py +10 -3
  125. mindspore/parallel/_tensor.py +4 -1
  126. mindspore/parallel/checkpoint_transform.py +13 -2
  127. mindspore/parallel/shard.py +17 -10
  128. mindspore/profiler/common/util.py +1 -0
  129. mindspore/profiler/parser/ascend_hccl_generator.py +232 -0
  130. mindspore/profiler/parser/ascend_msprof_exporter.py +86 -43
  131. mindspore/profiler/parser/ascend_msprof_generator.py +196 -9
  132. mindspore/profiler/parser/ascend_op_generator.py +1 -1
  133. mindspore/profiler/parser/ascend_timeline_generator.py +6 -182
  134. mindspore/profiler/parser/base_timeline_generator.py +1 -1
  135. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +2 -2
  136. mindspore/profiler/parser/framework_parser.py +1 -1
  137. mindspore/profiler/parser/profiler_info.py +19 -0
  138. mindspore/profiler/profiling.py +46 -24
  139. mindspore/rewrite/api/pattern_engine.py +1 -1
  140. mindspore/rewrite/parsers/for_parser.py +7 -7
  141. mindspore/rewrite/parsers/module_parser.py +4 -4
  142. mindspore/rewrite/symbol_tree.py +1 -4
  143. mindspore/run_check/_check_version.py +5 -3
  144. mindspore/safeguard/rewrite_obfuscation.py +52 -28
  145. mindspore/scipy/ops.py +55 -5
  146. mindspore/scipy/optimize/__init__.py +3 -2
  147. mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
  148. mindspore/train/callback/_summary_collector.py +1 -1
  149. mindspore/train/dataset_helper.py +1 -0
  150. mindspore/train/model.py +2 -2
  151. mindspore/train/serialization.py +97 -11
  152. mindspore/train/summary/_summary_adapter.py +1 -1
  153. mindspore/train/summary/summary_record.py +23 -7
  154. mindspore/version.py +1 -1
  155. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +3 -2
  156. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +160 -151
  157. mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -406
  158. mindspore/ops/_op_impl/_custom_op/flash_attention/constants.py +0 -41
  159. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -467
  160. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -563
  161. mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -193
  162. mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -435
  163. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
  164. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
  165. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
  166. mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
  167. /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
  168. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
  169. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
  170. {mindspore-2.2.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
@@ -48,12 +48,12 @@ from mindspore.profiler.parser.msadvisor_analyzer import Msadvisor
48
48
  from mindspore.profiler.parser.profiler_info import ProfilerInfo
49
49
  from mindspore.common.api import _pynative_executor
50
50
  from mindspore.profiler.parser.ascend_msprof_exporter import AscendMsprofExporter
51
- from mindspore.profiler.parser.ascend_msprof_generator import AscendMsprofDataGenerator
51
+ from mindspore.profiler.parser.ascend_msprof_generator import AscendMsprofDataGenerator, AscendMsprofDataGeneratorOld
52
52
  from mindspore.profiler.parser.ascend_fpbp_generator import AscendFPBPGenerator
53
53
  from mindspore.profiler.parser.ascend_op_generator import AscendOPGenerator
54
54
  from mindspore.profiler.parser.ascend_steptrace_generator import AscendStepTraceGenerator
55
55
  from mindspore.profiler.parser.ascend_flops_generator import AscendFlopsGenerator
56
- from mindspore.profiler.parser.ascend_hccl_generator import AscendHCCLGenerator
56
+ from mindspore.profiler.parser.ascend_hccl_generator import AscendHCCLGenerator, AscendHCCLGeneratorOld
57
57
 
58
58
  INIT_OP_NAME = 'Default/InitDataSetQueue'
59
59
 
@@ -274,16 +274,20 @@ def _parse_host_info(input_file, output_timeline_file, output_memory_file, is_de
274
274
 
275
275
 
276
276
  def _ascend_graph_msprof_generator(source_path, model_iteration_dict):
277
+ """Executing the msprof export mode."""
277
278
  try:
279
+ ProfilerInfo.set_export_start_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
278
280
  msprof_exporter = AscendMsprofExporter(source_path)
279
- msprof_exporter.export(model_iteration_dict)
281
+ flag = msprof_exporter.export(model_iteration_dict)
282
+ ProfilerInfo.set_export_end_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
283
+ return flag
284
+
280
285
  except ProfilerException as err:
281
286
  logger.warning(err.message)
282
- finally:
283
- pass
287
+ return False
284
288
 
285
289
 
286
- def _ascend_graph_msprof_analyse(source_path):
290
+ def _ascend_graph_msprof_analyse(source_path, flag):
287
291
  """
288
292
  Ascend graph model msprof data analyse.
289
293
 
@@ -294,7 +298,10 @@ def _ascend_graph_msprof_analyse(source_path):
294
298
  df_op_statistic = []
295
299
  df_step_trace = []
296
300
  try:
297
- msprof_analyser = AscendMsprofDataGenerator(os.path.join(source_path, 'summary'))
301
+ if flag:
302
+ msprof_analyser = AscendMsprofDataGenerator(os.path.join(source_path, 'summary'))
303
+ else:
304
+ msprof_analyser = AscendMsprofDataGeneratorOld(os.path.join(source_path, 'summary'))
298
305
  df_op_summary, df_op_statistic, df_step_trace = msprof_analyser.parse()
299
306
  except ProfilerException as err:
300
307
  logger.warning(err.message)
@@ -436,6 +443,7 @@ class Profiler:
436
443
  self._ascend_profiler = None
437
444
  self._timeline_size_limit_byte = 500 * 1024 * 1024 # 500MB
438
445
  self._parallel_strategy = True
446
+ self._model_iteration_dict = None
439
447
  _environment_check()
440
448
  # default aicore_metrics type is ArithmeticUtilization
441
449
  self._aicore_metrics_id = 0
@@ -450,7 +458,6 @@ class Profiler:
450
458
  self._sync_enable = True
451
459
  self._stop_time = 0
452
460
  self._dynamic_status = False
453
- self._model_iteration_dict = None
454
461
  self._profile_framework = "all"
455
462
  self._msprof_enable = os.getenv("PROFILER_SAMPLECONFIG")
456
463
  if self._msprof_enable:
@@ -609,9 +616,21 @@ class Profiler:
609
616
  model_iteration_dict: Dictionary with model id as the key and iteration id as the value, Default: ``None``.
610
617
  """
611
618
  self._model_iteration_dict = model_iteration_dict
619
+
620
+ self._init_profiler_info()
621
+ self._is_support_step_info_collect()
622
+ parallel_mode = get_auto_parallel_context("parallel_mode")
623
+ stage_num = get_auto_parallel_context("pipeline_stages")
624
+
625
+ ProfilerInfo.set_parallel_info(parallel_mode, stage_num)
626
+ ProfilerInfo.set_rank_size(self._rank_size)
627
+ ProfilerInfo.set_heterogeneous(self._is_heterogeneous)
612
628
  if offline_path:
613
629
  if self._is_offline_parser():
630
+ ProfilerInfo.set_analyse_start_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
614
631
  self._ascend_graph_analyse()
632
+ ProfilerInfo.set_analyse_end_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
633
+ ProfilerInfo.save(self._output_path)
615
634
  _offline_parse(offline_path)
616
635
  return
617
636
  if self._msprof_enable:
@@ -645,15 +664,7 @@ class Profiler:
645
664
  logger.warning("The parameter 'profile_framework' is not support for CPU, so there no host_info"
646
665
  " directory in the output path.")
647
666
  logger.info("Profiling: all the data have been analyzed.")
648
- self._init_profiler_info()
649
- self._is_support_step_info_collect()
650
- parallel_mode = get_auto_parallel_context("parallel_mode")
651
- stage_num = get_auto_parallel_context("pipeline_stages")
652
-
653
- ProfilerInfo.set_parallel_info(parallel_mode, stage_num)
654
667
  ProfilerInfo.set_analyse_end_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
655
- ProfilerInfo.set_rank_size(self._rank_size)
656
- ProfilerInfo.set_heterogeneous(self._is_heterogeneous)
657
668
  ProfilerInfo.save(self._output_path)
658
669
 
659
670
  def start(self):
@@ -785,6 +796,8 @@ class Profiler:
785
796
 
786
797
  self._stop_time = int(time.time() * 10000000)
787
798
  ProfilerInfo.set_profiling_stop_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
799
+ self._init_profiler_info()
800
+ ProfilerInfo.save(self._output_path)
788
801
  logger.info("Profiling: stop time: %d", self._stop_time)
789
802
 
790
803
  def _profiler_init(self, kwargs):
@@ -1158,7 +1171,7 @@ class Profiler:
1158
1171
 
1159
1172
  def _ascend_flops_analyse(self, op_summary):
1160
1173
  """Get op FLOPs from op_summary, write output_op_flops_x.csv."""
1161
- if len(op_summary.dtype) != 18:
1174
+ if 'vector_fops' not in op_summary.dtype.names and 'cube_fops' not in op_summary.dtype.names:
1162
1175
  logger.warning("[Profiler] Can not found cube fops and vector fops data in the op summary.")
1163
1176
  return
1164
1177
 
@@ -1195,7 +1208,7 @@ class Profiler:
1195
1208
  finally:
1196
1209
  pass
1197
1210
 
1198
- def _ascend_graph_hccl_analyse(self, source_path):
1211
+ def _ascend_graph_hccl_analyse(self, source_path, steptrace, flag):
1199
1212
  """Analyse hccl profiler info."""
1200
1213
  if not self._profile_communication:
1201
1214
  return
@@ -1209,8 +1222,10 @@ class Profiler:
1209
1222
 
1210
1223
  hccl_raw_path = os.path.join(self._output_path, f'hccl_raw_{dev_id}.csv')
1211
1224
  hccl_raw_path = validate_and_normalize_path(hccl_raw_path)
1212
-
1213
- hccl_analyse = AscendHCCLGenerator(os.path.join(source_path, 'timeline'))
1225
+ if flag:
1226
+ hccl_analyse = AscendHCCLGenerator(os.path.join(source_path, 'timeline'), steptrace)
1227
+ else:
1228
+ hccl_analyse = AscendHCCLGeneratorOld(os.path.join(source_path, 'timeline'))
1214
1229
  hccl_analyse.parse()
1215
1230
  hccl_analyse.write(hccl_raw_path)
1216
1231
 
@@ -1252,8 +1267,12 @@ class Profiler:
1252
1267
  source_path = os.path.join(self._output_path, job_id)
1253
1268
  self._minddata_analyse(source_path)
1254
1269
  if self._op_time:
1255
- _ascend_graph_msprof_generator(source_path, self._model_iteration_dict)
1256
- op_summary, op_statistic, steptrace = _ascend_graph_msprof_analyse(source_path)
1270
+ flag = _ascend_graph_msprof_generator(source_path, self._model_iteration_dict)
1271
+ if not flag:
1272
+ logger.warning('Current driver package not support all export mode, use single export mode, '
1273
+ 'this may lead to performance degradation. Suggest upgrading the driver package.')
1274
+ ProfilerInfo.set_export_flag(flag)
1275
+ op_summary, op_statistic, steptrace = _ascend_graph_msprof_analyse(source_path, flag)
1257
1276
  self._ascend_op_analyse(op_summary, op_statistic, self._dynamic_status)
1258
1277
  self._ascend_timeline_analyse(op_summary, steptrace)
1259
1278
  graph_ids = np.unique(op_summary['Model ID']).tolist()
@@ -1264,7 +1283,7 @@ class Profiler:
1264
1283
  self._ascend_dynamic_net_analyse(op_summary)
1265
1284
  self._ascend_flops_analyse(op_summary)
1266
1285
  self._ascend_graph_memory_analyse(points)
1267
- self._ascend_graph_hccl_analyse(source_path)
1286
+ self._ascend_graph_hccl_analyse(source_path, steptrace, flag)
1268
1287
  self._ascend_graph_msadvisor_analyse(job_id)
1269
1288
  ProfilerInfo.set_graph_ids(graph_ids)
1270
1289
 
@@ -1459,6 +1478,9 @@ class Profiler:
1459
1478
  job_id = self._ascend_job_id.rstrip('/').split('/')[-1]
1460
1479
  if job_id.startswith('PROF'):
1461
1480
  device_dir = [dir for dir in os.listdir(self._ascend_job_id) if dir.startswith('device')]
1481
+ info_file_path = get_file_path(os.path.join(self._ascend_job_id, device_dir[0]), "info.json")
1482
+ training_rank_id, _ = self._parse_info_json(info_file_path)
1483
+ self._rank_id = int(training_rank_id)
1462
1484
  return os.path.join(job_id, device_dir[0])
1463
1485
  return job_id
1464
1486
 
@@ -1489,8 +1511,8 @@ class Profiler:
1489
1511
  "profiler will ignore this job dir.", job_dir)
1490
1512
  continue
1491
1513
 
1492
- _, training_device_id = self._parse_info_json(info_file_path)
1493
1514
  job_start_time = self._parse_start_log(start_file_path)
1515
+ _, training_device_id = self._parse_info_json(info_file_path)
1494
1516
 
1495
1517
  if self._dev_id != training_device_id:
1496
1518
  logger.debug("Find profiling find job path %s, but not current training device id. "
@@ -364,7 +364,7 @@ class PatternEngine:
364
364
  continue
365
365
  if cur_node.get_node_type() == NodeType.Tree:
366
366
  subtree = TreeNodeHelper.get_sub_tree(cur_node)
367
- self.apply(subtree)
367
+ _ = self.apply(subtree)
368
368
  visited.append(cur_node)
369
369
  queue.extend(cur_node.get_users())
370
370
  continue
@@ -72,7 +72,7 @@ class ForParser(Parser):
72
72
  return
73
73
  iter_code = astunparse.unparse(node.iter)
74
74
  if not iter_code.startswith(EVAL_WHITE_LIST):
75
- logger.warning(
75
+ logger.info(
76
76
  f"For MindSpore Rewrtie, illegal iteration condition for For node, it must start with{EVAL_WHITE_LIST}")
77
77
  return
78
78
  if "self" in iter_code:
@@ -82,7 +82,7 @@ class ForParser(Parser):
82
82
  except (NameError, TypeError) as e:
83
83
  _info = f"For MindSpore Rewrtie, when eval '{iter_code}' by using JIT Fallback feature, " \
84
84
  f"an error occurred: {str(e)}"
85
- logger.warning(_info)
85
+ logger.info(_info)
86
86
  stree.try_append_python_node(node, node, node_manager)
87
87
  return
88
88
 
@@ -107,7 +107,7 @@ class ForParser(Parser):
107
107
  ast_functiondef.body.insert(index, new_node)
108
108
  index += 1
109
109
  # Expand "for" statement and replace the body with Pass
110
- for body in node.body:
110
+ for body in node.body[:]:
111
111
  node.body.remove(body)
112
112
  node.body.append(ast.Pass())
113
113
 
@@ -115,13 +115,13 @@ class ForParser(Parser):
115
115
  stree.on_change(Event.CodeChangeEvent)
116
116
  return
117
117
  if isinstance(iter_obj, range):
118
- logger.warning("For MindSpore Rewrite, range not support.")
118
+ logger.info("For MindSpore Rewrite, range not support.")
119
119
  elif isinstance(iter_obj, zip):
120
- logger.warning("For MindSpore Rewrite, zip not support.")
120
+ logger.info("For MindSpore Rewrite, zip not support.")
121
121
  elif isinstance(iter_obj, enumerate):
122
- logger.warning("For MindSpore Rewrite, enumerate not support.")
122
+ logger.info("For MindSpore Rewrite, enumerate not support.")
123
123
  else:
124
- logger.warning(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
124
+ logger.info(f"For MindSpore Rewrite, not supported type: {type(iter_obj).__name__}")
125
125
  stree.try_append_python_node(node, node, node_manager)
126
126
  return
127
127
 
@@ -170,15 +170,15 @@ class ModuleParser(Parser):
170
170
  level_count += 1
171
171
  continue
172
172
  except Exception as e: # pylint: disable=W0703
173
- logger.warning(f"For MindSpore Rewrite, in module parser, process import code: "
174
- f"{import_code} failed: {e}. Ignore this import code.")
173
+ logger.info(f"For MindSpore Rewrite, in module parser, process import code: "
174
+ f"{import_code} failed: {e}. Ignore this import code.")
175
175
  return None, None
176
176
  else:
177
177
  # try test code success
178
178
  return import_node_test.module, file_path_tmp
179
179
  # try codes with all level failed
180
- logger.warning(f"For MindSpore Rewrite, in module parser, test import code: "
181
- f"{astunparse.unparse(import_node).strip()} failed. Ignore this import code.")
180
+ logger.info(f"For MindSpore Rewrite, in module parser, test import code: "
181
+ f"{astunparse.unparse(import_node).strip()} failed. Ignore this import code.")
182
182
  return None, None
183
183
 
184
184
  @staticmethod
@@ -226,7 +226,6 @@ class SymbolTree(Observer, Observable, NodeManager):
226
226
  if class_str not in classes:
227
227
  classes.add(node.name)
228
228
  return node
229
- return
230
229
 
231
230
  def visit_Try(self, node: ast.Try) -> Any:
232
231
  if isinstance(node.body[0], (ast.Import, ast.ImportFrom)):
@@ -234,14 +233,12 @@ class SymbolTree(Observer, Observable, NodeManager):
234
233
  if import_str not in imports:
235
234
  imports.add(import_str)
236
235
  return node
237
- return
238
236
 
239
237
  def visit_Import(self, node: ast.Import) -> Any:
240
238
  import_str = astunparse.unparse(node)
241
239
  if import_str not in imports:
242
240
  imports.add(import_str)
243
241
  return node
244
- return
245
242
 
246
243
  def visit_ImportFrom(self, node: ast.ImportFrom) -> Any:
247
244
  """
@@ -818,7 +815,7 @@ class SymbolTree(Observer, Observable, NodeManager):
818
815
  for node in new_nodes:
819
816
  self.insert_node(node, base_node, False, node_manager, True)
820
817
  base_node = node
821
- self.erase_node(old_node)
818
+ _ = self.erase_node(old_node)
822
819
  return new_nodes[-1]
823
820
 
824
821
  def set_node_arg(self, node: Union[Node, str], index: int, arg: Union[ScopedValue, str]):
@@ -259,7 +259,7 @@ class AscendEnvChecker(EnvChecker):
259
259
 
260
260
  def __init__(self, library_path):
261
261
  self.library_path = library_path
262
- self.version = ["7.0"]
262
+ self.version = ["7.1"]
263
263
  atlas_nnae_version = "/usr/local/Ascend/nnae/latest/compiler/version.info"
264
264
  atlas_toolkit_version = "/usr/local/Ascend/ascend-toolkit/latest/compiler/version.info"
265
265
  hisi_fwk_version = "/usr/local/Ascend/latest/compiler/version.info"
@@ -398,11 +398,13 @@ class AscendEnvChecker(EnvChecker):
398
398
 
399
399
  def set_env(self):
400
400
  curr_path = os.path.abspath(os.path.dirname(__file__))
401
+ cust_aicpu_path = os.path.abspath(os.path.join(curr_path, "../lib/plugin/ascend/custom_aicpu_ops"))
402
+ cust_aicore_path = os.path.abspath(os.path.join(curr_path, "../lib/plugin/ascend/custom_aicore_ops"))
401
403
  if os.getenv('ASCEND_CUSTOM_OPP_PATH'):
402
404
  os.environ['ASCEND_CUSTOM_OPP_PATH'] = os.environ['ASCEND_CUSTOM_OPP_PATH'] + ":" + \
403
- curr_path + "/../lib/plugin/ascend/custom_aicpu_ops"
405
+ cust_aicore_path + ":" + cust_aicpu_path
404
406
  else:
405
- os.environ['ASCEND_CUSTOM_OPP_PATH'] = curr_path + "/../lib/plugin/ascend/custom_aicpu_ops"
407
+ os.environ['ASCEND_CUSTOM_OPP_PATH'] = cust_aicore_path + ":" + cust_aicpu_path
406
408
  plugin_dir = os.path.dirname(self.library_path)
407
409
  akg_dir = os.path.join(plugin_dir, "ascend")
408
410
  AscendEnvChecker._concat_variable('LD_LIBRARY_PATH', akg_dir)
@@ -17,13 +17,14 @@ import os
17
17
  import re
18
18
  import secrets
19
19
  from pathlib import Path
20
- import numpy as np
21
20
 
22
21
  from mindspore import ops, nn
23
22
  from mindspore.common.tensor import Tensor
24
23
  from mindspore import log as logger
25
24
  from mindspore import load_checkpoint, save_checkpoint
26
25
  from mindspore.rewrite import SymbolTree, Node, NodeType, TreeNodeHelper, ScopedValue
26
+ from mindspore.rewrite.parsers.class_def_parser import ClassDefParser
27
+ from mindspore.rewrite.parsers.class_def_parser import ModuleParser
27
28
 
28
29
  OBF_RATIOS_LENGTH = 1
29
30
  MAX_OBF_RATIOS_NUM = 50
@@ -31,7 +32,7 @@ OBF_RATIOS_WIDTH = 0
31
32
  OBF_RATIOS_INSERT_INDEX = 0
32
33
 
33
34
 
34
- def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
35
+ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', obfuscate_scale=100):
35
36
  """
36
37
  obfuscate the plaintext checkpoint files. Usually used in conjunction with
37
38
  :func:`mindspore.load_obf_params_into_net`.
@@ -49,8 +50,9 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
49
50
  (such as transformer layers or resnet blocks). If target_modules is ``None``, the function would search
50
51
  target modules by itself. If found, the searched target module would be used, otherwise suggested target
51
52
  modules would be given with warning log. Default: ``None``.
52
- saved_path (str): The directory path for saving obfuscated ckpt files and obf_ratios (a numpy file). obf_ratios
53
- is the necessary data that needs to be load when running obfuscated network. Default: ``'./'``.
53
+ saved_path (str): The directory path for saving obfuscated ckpt files. Default: ``'./'``.
54
+ obfuscate_scale (Union[float, int]): Obfuscate scale of weights. The generated random obf_ratios will be in
55
+ range of (1 / obfuscate_scale, obfuscate_scale). Default: 100.
54
56
 
55
57
  Raises:
56
58
  TypeError: If `network` is not nn.Cell.
@@ -66,6 +68,9 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
66
68
  ValueError: If the third string of `target_modules` is not in the format of 'obfuscate_layers:all' or
67
69
  'obfuscate_layers:int'.
68
70
 
71
+ Returns:
72
+ list[float], obf_ratios, which is the necessary data that needs to be load when running obfuscated network.
73
+
69
74
  Examples:
70
75
  >>> from mindspore import obfuscate_ckpt, save_checkpoint
71
76
  >>> # Refer to https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
@@ -88,10 +93,13 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
88
93
  if not _check_valid_target(network, to_split_modules):
89
94
  raise ValueError("The obfuscate module path {} is not exist, please check the input 'target_modules'."
90
95
  .format(to_split_modules))
96
+ if (not isinstance(obfuscate_scale, (float, int))) or (obfuscate_scale <= 1):
97
+ raise ValueError("obfuscate_scale must be float or int, and larger than 1, but got {}."
98
+ .format(obfuscate_scale))
91
99
  # generate and save obf_ratios to saved_path
92
100
  path_list = to_split_modules[0].split('/')
93
101
  target_list = to_split_modules[1].split('|')
94
- global OBF_RATIOS_WIDTH, OBF_RATIOS_LENGTH
102
+ global OBF_RATIOS_LENGTH
95
103
  number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
96
104
  if number_of_ratios > MAX_OBF_RATIOS_NUM:
97
105
  OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
@@ -99,14 +107,13 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
99
107
  obf_ratios = []
100
108
  secrets_generator = secrets.SystemRandom()
101
109
  for _ in range(number_of_ratios):
102
- secure_float = secrets_generator.uniform(0.01, 100)
110
+ secure_float = secrets_generator.uniform(1 / obfuscate_scale, obfuscate_scale)
103
111
  obf_ratios.append(secure_float)
104
- np.save(os.path.abspath(saved_path) + '/' + 'obf_ratios.npy', np.array(obf_ratios))
105
112
  # start obfuscate ckpt
106
113
  ckpt_dir_files = os.listdir(ckpt_files)
107
114
  for ckpt_name in ckpt_dir_files:
108
- if Path(ckpt_files + ckpt_name).is_dir():
109
- sub_path = ckpt_files + ckpt_name
115
+ sub_path = os.path.abspath(ckpt_files) + '/' + ckpt_name
116
+ if Path(sub_path).is_dir():
110
117
  sub_ckpt_file_list = os.listdir(sub_path)
111
118
  new_saved_path = os.path.abspath(saved_path) + '/' + ckpt_name
112
119
  if not os.path.exists(new_saved_path):
@@ -124,20 +131,24 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./'):
124
131
  continue
125
132
  _obfuscate_single_ckpt(os.path.abspath(ckpt_files) + '/' + ckpt_name, obf_ratios, path_list,
126
133
  target_list, saved_path)
134
+ return obf_ratios
127
135
 
128
136
 
129
137
  def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_path):
130
138
  """Obfuscate single ckpt file"""
131
139
  module_has_been_obfuscated = set()
132
- ckpt_param = load_checkpoint(ckpt_name)
140
+ try:
141
+ ckpt_param = load_checkpoint(ckpt_name)
142
+ except (ValueError, TypeError, OSError):
143
+ logger.error("Load checkpoint failed for file {}.".format(ckpt_name))
144
+ return None
133
145
  obf_ratios_index = -1
134
- global OBF_RATIOS_LENGTH, OBF_RATIOS_WIDTH
135
146
  for item in ckpt_param:
136
147
  module = _get_valid_module(item, path_list, target_list)
137
148
  if module:
138
149
  layer_index = _judge_layer_index(item)
139
150
  if layer_index >= OBF_RATIOS_LENGTH:
140
- break
151
+ continue
141
152
  if module not in module_has_been_obfuscated:
142
153
  module_has_been_obfuscated.add(module)
143
154
  obf_ratios_index += 1
@@ -150,9 +161,10 @@ def _obfuscate_single_ckpt(ckpt_name, obf_ratios, path_list, target_list, saved_
150
161
  ckpt_file_name = ckpt_name.split('/')[-1]
151
162
  obf_ckpt_file_name = ckpt_file_name.split('.')[0] + '_obf' + '.ckpt'
152
163
  save_checkpoint(obf_param_list, os.path.abspath(saved_path) + '/' + obf_ckpt_file_name)
164
+ return None
153
165
 
154
166
 
155
- def load_obf_params_into_net(network, target_modules, obf_ratios, **kwargs):
167
+ def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_num=1, **kwargs):
156
168
  """
157
169
  load obfuscate ratios into obfuscated network. Usually used in conjunction with :func:`mindspore.obfuscate_ckpt`
158
170
  interface.
@@ -166,6 +178,7 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, **kwargs):
166
178
  If target_modules has the third value, it should be in the format of 'obfuscate_layers:all' or
167
179
  'obfuscate_layers:int', which represents the number of layers need to be obfuscated of duplicate layers
168
180
  (such as transformer layers or resnet blocks).
181
+ data_parallel_num (int): The data parallel number of parallel training. Default: 1.
169
182
  obf_ratios (Tensor): The obf ratios generated when execute :func:`mindspore.obfuscate_ckpt`.
170
183
  kwargs (dict): Configuration options dictionary.
171
184
 
@@ -211,6 +224,8 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, **kwargs):
211
224
  raise ValueError("obf_ratios can not be empty.")
212
225
  if not _check_valid_target(network, target_modules):
213
226
  raise ValueError("{} is not exist, please check the input 'target_modules'.".format(target_modules))
227
+ if (not isinstance(data_parallel_num, int)) or (data_parallel_num <= 0):
228
+ raise ValueError("data_parallel_num must be positive number, but got {}.".format(data_parallel_num))
214
229
  if len(target_modules) >= 1 and target_modules[0] == '/':
215
230
  target_modules[0] = ''
216
231
  path_list = target_modules[0].split('/')
@@ -219,13 +234,13 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, **kwargs):
219
234
  for _ in range(path_len):
220
235
  target_list.append([])
221
236
  target_list.append(target_modules[1].split('|'))
222
- global MAX_OBF_RATIOS_NUM, OBF_RATIOS_WIDTH, OBF_RATIOS_LENGTH
237
+ global MAX_OBF_RATIOS_NUM, OBF_RATIOS_LENGTH
223
238
  number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
224
239
  if number_of_ratios > MAX_OBF_RATIOS_NUM:
225
240
  OBF_RATIOS_LENGTH = MAX_OBF_RATIOS_NUM // OBF_RATIOS_WIDTH
226
241
  number_of_ratios = OBF_RATIOS_LENGTH * OBF_RATIOS_WIDTH
227
242
  MAX_OBF_RATIOS_NUM = number_of_ratios
228
- rewrite_network = _obfuscate_network(network, path_list, target_list, **kwargs)
243
+ rewrite_network = _obfuscate_network(network, path_list, target_list, data_parallel_num=data_parallel_num, **kwargs)
229
244
  setattr(rewrite_network, 'obf_ratios', obf_ratios)
230
245
  return rewrite_network
231
246
 
@@ -263,7 +278,7 @@ def _check_valid_target(network, target_modules):
263
278
  if not target_modules[1]:
264
279
  raise ValueError("{} should be a non-empty string value, in the form of 'D1|D2'"
265
280
  .format(target_modules[1]))
266
- if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0])\
281
+ if not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\/*_*)*', string=target_modules[0]) \
267
282
  or not re.fullmatch(pattern=r'([a-zA-Z]*[0-9]*\|*_*)*', string=target_modules[1]):
268
283
  raise ValueError("please check the input 'target_modules'{},it should be in the form of ['A/B/C', 'D1|D2']."
269
284
  "target_modules[0] can only contain uppercase and lowercase letters, numbers, '_' and '/',"
@@ -297,9 +312,9 @@ def _check_valid_target(network, target_modules):
297
312
  # check whether target_list is valid
298
313
  global OBF_RATIOS_WIDTH
299
314
  OBF_RATIOS_WIDTH = 0
300
- for j in range(len(target_list)):
301
- if not hasattr(net, target_list[j]):
302
- logger.warning("{} does not exist in the path {}".format(target_list[j], target_modules[0]))
315
+ for target in target_list:
316
+ if not hasattr(net, target):
317
+ logger.warning("{} does not exist in the path {}".format(target, target_modules[0]))
303
318
  else:
304
319
  OBF_RATIOS_WIDTH += 1
305
320
  if OBF_RATIOS_WIDTH == 0:
@@ -328,6 +343,7 @@ def _update_max_obf_ratios_num(target_modules):
328
343
 
329
344
  def _get_default_target_modules(ckpt_files):
330
345
  """Get the default or suggested target modules, if the target modules is None."""
346
+
331
347
  def _split_to_path_and_target(module, target):
332
348
  # split module into path list and target list
333
349
  target_index = module.index(target)
@@ -370,7 +386,11 @@ def _get_default_target_modules(ckpt_files):
370
386
  for ckpt_name in ckpt_dir_files:
371
387
  if not ckpt_name.endswith('.ckpt'):
372
388
  continue
373
- ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name)
389
+ try:
390
+ ckpt_param = load_checkpoint(os.path.abspath(ckpt_files) + '/' + ckpt_name)
391
+ except (ValueError, TypeError, OSError):
392
+ logger.error("Load checkpoint failed for file {}.".format(os.path.abspath(ckpt_files) + '/' + ckpt_name))
393
+ return None
374
394
  for item in ckpt_param:
375
395
  param_path = _remove_digit(item)
376
396
  param_path = '/'.join(param_path)
@@ -396,9 +416,9 @@ def _get_valid_module(item, path_list, target_list):
396
416
  tar_path = '/'.join(path_list)
397
417
  # update the weights with obf_ratios in target module
398
418
  if net_path == tar_path:
399
- for i in range(len(target_list)):
400
- if target_list[i] in item.split('.'):
401
- target_index = item.split('.').index(target_list[i])
419
+ for target in target_list:
420
+ if target in item.split('.'):
421
+ target_index = item.split('.').index(target)
402
422
  module = ''.join(item.split('.')[:target_index + 1])
403
423
  return module
404
424
  return None
@@ -413,7 +433,7 @@ def _remove_digit(item):
413
433
  return param_path
414
434
 
415
435
 
416
- def _obfuscate_network(model, path_list, target_list, **kwargs):
436
+ def _obfuscate_network(model, path_list, target_list, data_parallel_num=1, **kwargs):
417
437
  """obfuscate original network, including add mul operation and add inputs for passing obf_ratio."""
418
438
 
419
439
  def _insert_input(stree: SymbolTree, arg_name: str = 'y_obf'):
@@ -426,7 +446,6 @@ def _obfuscate_network(model, path_list, target_list, **kwargs):
426
446
  # the insert input node name would be 'input_y_obf'
427
447
  new_input_node = last_input.create_input(arg_name)
428
448
  stree.insert(position, new_input_node)
429
- return new_input_node
430
449
 
431
450
  def _insert_mul(stree: SymbolTree, node: Node, index: int):
432
451
  """add mul operation for original network"""
@@ -436,7 +455,12 @@ def _obfuscate_network(model, path_list, target_list, **kwargs):
436
455
  sv: ScopedValue = ScopedValue.create_naming_value(v + f'[{index}]')
437
456
  arg_list.append(sv)
438
457
  target_list = node.get_targets().copy()
439
- new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul')
458
+ if data_parallel_num > 1:
459
+ logger.info("Data parallel number is: {}".format(data_parallel_num))
460
+ new_mul_node = node.create_call_cell(cell=ops.Mul().shard(((data_parallel_num, 1), ())),
461
+ targets=target_list, args=arg_list, name='mul')
462
+ else:
463
+ new_mul_node = node.create_call_cell(cell=ops.Mul(), targets=target_list, args=arg_list, name='mul')
440
464
  position = stree.after(node)
441
465
  stree.insert(position, new_mul_node)
442
466
 
@@ -459,6 +483,8 @@ def _obfuscate_network(model, path_list, target_list, **kwargs):
459
483
  if input_y_node is None:
460
484
  return
461
485
  arg_list = subnode.get_args().copy()
486
+ kwargs_list = list(subnode.get_kwargs().values())
487
+ arg_list.extend(kwargs_list)
462
488
  v: str = input_y_node.get_targets()[0].value
463
489
  arg_obf: ScopedValue = ScopedValue.create_naming_value("y_obf=" + v)
464
490
  arg_list.append(arg_obf)
@@ -482,13 +508,11 @@ def _obfuscate_network(model, path_list, target_list, **kwargs):
482
508
 
483
509
  def _register_denied_func_decorators(fn):
484
510
  """set the function decorators which should be denied for parse"""
485
- from mindspore.rewrite.parsers.class_def_parser import ClassDefParser
486
511
  name = "denied_function_decorator_list"
487
512
  setattr(ClassDefParser, name, fn)
488
513
 
489
514
  def _register_denied_class_decorators(fn):
490
515
  """set the class decorators which should be denied for parse"""
491
- from mindspore.rewrite.parsers.class_def_parser import ModuleParser
492
516
  name = "denied_class_decorator_list"
493
517
  setattr(ModuleParser, name, fn)
494
518
 
mindspore/scipy/ops.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2021 Huawei Technologies Co., Ltd
1
+ # Copyright 2021-2023 Huawei Technologies Co., Ltd
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -156,14 +156,64 @@ class LU(PrimitiveWithInfer):
156
156
 
157
157
 
158
158
  class LinearSumAssignment(Primitive):
159
- """Solve the linear sum assignment problem."""
159
+ r"""
160
+ Solve the linear sum assignment problem.
161
+
162
+ The assignment problem is represented as follows:
163
+
164
+ .. math::
165
+ min\sum_{i}^{} \sum_{j}^{} C_{i,j} X_{i,j}
166
+
167
+ where :math:`C` is cost matrix, :math:`X_{i,j} = 1` means column :math:`j` is assigned to row :math:`i` .
168
+
169
+ Inputs:
170
+ - **cost_matrix** (Tensor) - 2-D cost matrix. Tensor of shape :math:`(M, N)` .
171
+ - **dimension_limit** (Tensor, optional) - A scalar used to limit the actual size of the 2nd dimension of
172
+ ``cost_matrix``. Default is ``Tensor(sys.maxsize)``, which means no limitation. The type is 0-D int64
173
+ Tensor.
174
+ - **maximize** (bool) - Calculate a maximum weight matching if true, otherwise calculate a minimum weight
175
+ matching.
176
+
177
+ Outputs:
178
+ A tuple of tensors containing 'row_idx' and 'col_idx'.
179
+
180
+ - **row_idx** (Tensor) - Row indices of the problem. If `dimension_limit` is given, -1 would be padded at the
181
+ end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
182
+ - **col_idx** (Tensor) - Column indices of the problem. If `dimension_limit` is given, -1 would be padded at
183
+ the end. The shape is :math:`(N, )` , where :math:`N` is the minimum value of `cost_matrix` dimension.
184
+
185
+ Raises:
186
+ TypeError: If the data type of `cost_matrix` is not the type in [float16, float32, float64,
187
+ int8, int16, int32, int64, uint8, uint16, uint32, uint64, bool]
188
+ TypeError: If the type of `maximize` is not bool.
189
+ TypeError: If the data type of `dimension_limit` is not int64.
190
+ ValueError: If the rank of `cost_matrix` is not 2.
191
+ ValueError: If the number of input args is not 3.
192
+
193
+
194
+ Supported Platforms:
195
+ ``Ascend`` ``CPU``
196
+
197
+ Examples:
198
+ >>> import mindspore as ms
199
+ >>> import numpy as np
200
+ >>> from mindspore import Tensor
201
+ >>> from mindspore.scipy.ops import LinearSumAssignment
202
+ >>> lsap = LinearSumAssignment()
203
+ >>> cost_matrix = Tensor(np.array([[2, 3, 3], [3, 2, 3], [3, 3, 2]])).astype(ms.float64)
204
+ >>> dimension_limit = Tensor(2)
205
+ >>> maximize = False
206
+ >>> a, b = lsap(cost_matrix, dimension_limit, maximize)
207
+ >>> print(a)
208
+ [0 1 -1]
209
+ >>> print(b)
210
+ [0 1 -1]
211
+ """
160
212
 
161
213
  @prim_attr_register
162
214
  def __init__(self):
163
- super().__init__("LinearSumAssignment")
215
+ super().__init__(name="LinearSumAssignment")
164
216
  self.init_prim_io_names(inputs=['cost_matrix', 'dimension_limit', 'maximize'], outputs=['row_ind', 'col_ind'])
165
- self.add_prim_attr("cust_aicpu", "mindspore_aicpu_kernels")
166
-
167
217
 
168
218
  # pylint: disable=C0413,W0611
169
219
  from .ops_grad import get_bprpo_eigh, get_bprpo_trsm