mindstudio-probe 1.2.2__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/METADATA +3 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/RECORD +143 -144
  3. msprobe/README.md +25 -20
  4. msprobe/core/common/const.py +110 -66
  5. msprobe/core/common/decorator.py +50 -0
  6. msprobe/core/common/exceptions.py +3 -1
  7. msprobe/core/common/file_utils.py +25 -2
  8. msprobe/core/common/utils.py +30 -34
  9. msprobe/core/compare/acc_compare.py +43 -74
  10. msprobe/core/compare/check.py +2 -6
  11. msprobe/core/compare/highlight.py +2 -0
  12. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -1
  13. msprobe/core/compare/merge_result/merge_result.py +8 -2
  14. msprobe/core/compare/multiprocessing_compute.py +19 -12
  15. msprobe/core/compare/npy_compare.py +30 -12
  16. msprobe/core/compare/utils.py +20 -10
  17. msprobe/core/data_dump/api_registry.py +176 -0
  18. msprobe/core/data_dump/data_processor/base.py +2 -2
  19. msprobe/core/data_dump/data_processor/mindspore_processor.py +19 -32
  20. msprobe/core/data_dump/data_processor/pytorch_processor.py +45 -15
  21. msprobe/core/data_dump/json_writer.py +38 -35
  22. msprobe/core/grad_probe/constant.py +1 -0
  23. msprobe/core/grad_probe/grad_compare.py +1 -1
  24. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  25. msprobe/docs/01.installation.md +2 -1
  26. msprobe/docs/02.config_introduction.md +17 -15
  27. msprobe/docs/05.data_dump_PyTorch.md +70 -2
  28. msprobe/docs/06.data_dump_MindSpore.md +33 -12
  29. msprobe/docs/07.accuracy_checker_PyTorch.md +11 -1
  30. msprobe/docs/08.accuracy_checker_online_PyTorch.md +3 -1
  31. msprobe/docs/09.accuracy_checker_MindSpore.md +1 -1
  32. msprobe/docs/10.accuracy_compare_PyTorch.md +59 -33
  33. msprobe/docs/11.accuracy_compare_MindSpore.md +40 -16
  34. msprobe/docs/12.overflow_check_PyTorch.md +3 -1
  35. msprobe/docs/13.overflow_check_MindSpore.md +4 -2
  36. msprobe/docs/14.data_parse_PyTorch.md +1 -7
  37. msprobe/docs/18.online_dispatch.md +1 -1
  38. msprobe/docs/19.monitor.md +124 -62
  39. msprobe/docs/21.visualization_PyTorch.md +32 -13
  40. msprobe/docs/22.visualization_MindSpore.md +32 -13
  41. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  42. msprobe/docs/27.dump_json_instruction.md +278 -8
  43. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  44. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  45. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  46. msprobe/docs/FAQ.md +3 -11
  47. msprobe/docs/img/compare_result.png +0 -0
  48. msprobe/docs/img/merge_result.png +0 -0
  49. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  50. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  51. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  52. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  53. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  54. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  55. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  56. msprobe/mindspore/__init__.py +4 -3
  57. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +6 -1
  58. msprobe/mindspore/api_accuracy_checker/api_runner.py +19 -9
  59. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  60. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +602 -0
  61. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  62. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -1
  63. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +2 -1
  64. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  65. msprobe/mindspore/common/const.py +61 -0
  66. msprobe/mindspore/common/utils.py +31 -19
  67. msprobe/mindspore/compare/ms_compare.py +27 -19
  68. msprobe/mindspore/compare/ms_graph_compare.py +6 -5
  69. msprobe/mindspore/debugger/debugger_config.py +6 -4
  70. msprobe/mindspore/debugger/precision_debugger.py +22 -10
  71. msprobe/mindspore/dump/dump_tool_factory.py +5 -3
  72. msprobe/mindspore/dump/hook_cell/api_register.py +142 -0
  73. msprobe/mindspore/dump/hook_cell/hook_cell.py +9 -10
  74. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +24 -26
  75. msprobe/mindspore/dump/jit_dump.py +14 -9
  76. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +22 -56
  77. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +0 -1
  78. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +10 -6
  79. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  80. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  81. msprobe/mindspore/grad_probe/global_context.py +2 -0
  82. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  83. msprobe/mindspore/grad_probe/hook.py +2 -4
  84. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  85. msprobe/mindspore/monitor/module_hook.py +354 -302
  86. msprobe/mindspore/monitor/utils.py +46 -4
  87. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  88. msprobe/mindspore/service.py +23 -17
  89. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  90. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +11 -6
  91. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +2 -2
  92. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +4 -5
  93. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +5 -5
  94. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +25 -6
  95. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -19
  96. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  97. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  98. msprobe/pytorch/common/utils.py +29 -7
  99. msprobe/pytorch/debugger/precision_debugger.py +10 -1
  100. msprobe/pytorch/dump/module_dump/module_dump.py +4 -3
  101. msprobe/pytorch/dump/module_dump/module_processer.py +12 -6
  102. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  103. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  104. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  105. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  106. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  107. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  108. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  109. msprobe/pytorch/function_factory.py +1 -1
  110. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  111. msprobe/pytorch/hook_module/api_register.py +131 -0
  112. msprobe/pytorch/hook_module/hook_module.py +19 -14
  113. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  114. msprobe/pytorch/hook_module/support_wrap_ops.yaml +172 -75
  115. msprobe/pytorch/monitor/csv2tb.py +8 -2
  116. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  117. msprobe/pytorch/monitor/module_hook.py +131 -105
  118. msprobe/pytorch/monitor/module_metric.py +3 -0
  119. msprobe/pytorch/monitor/optimizer_collect.py +55 -4
  120. msprobe/pytorch/monitor/unittest/test_monitor.py +1 -1
  121. msprobe/pytorch/monitor/utils.py +68 -1
  122. msprobe/pytorch/online_dispatch/compare.py +0 -2
  123. msprobe/pytorch/online_dispatch/dispatch.py +9 -0
  124. msprobe/pytorch/online_dispatch/dump_compare.py +3 -0
  125. msprobe/pytorch/online_dispatch/utils.py +3 -0
  126. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  127. msprobe/pytorch/parse_tool/lib/utils.py +2 -1
  128. msprobe/pytorch/pt_config.py +11 -7
  129. msprobe/pytorch/service.py +11 -8
  130. msprobe/visualization/builder/graph_builder.py +44 -5
  131. msprobe/visualization/builder/msprobe_adapter.py +0 -1
  132. msprobe/visualization/compare/graph_comparator.py +42 -38
  133. msprobe/visualization/compare/mode_adapter.py +0 -19
  134. msprobe/visualization/graph/base_node.py +8 -1
  135. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  136. msprobe/visualization/graph/graph.py +0 -11
  137. msprobe/visualization/graph/node_op.py +1 -2
  138. msprobe/visualization/graph_service.py +1 -1
  139. msprobe/visualization/utils.py +2 -33
  140. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  141. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  142. msprobe/pytorch/hook_module/api_registry.py +0 -166
  143. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  144. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  145. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  146. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  147. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  148. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  149. msprobe/pytorch/parse.py +0 -19
  150. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/LICENSE +0 -0
  151. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/WHEEL +0 -0
  152. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/entry_points.txt +0 -0
  153. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-1.3.0.dist-info}/top_level.txt +0 -0
@@ -564,15 +564,15 @@ tensor:
564
564
  - all
565
565
  - amax
566
566
  - amin
567
+ - angle
567
568
  - any
568
569
  - arccos
569
570
  - arccosh
570
- - argmax
571
- - angle
572
571
  - arcsin
573
572
  - arcsinh
574
573
  - arctan
575
574
  - arctanh
575
+ - argmax
576
576
  - argmin
577
577
  - argsort
578
578
  - asin
@@ -582,19 +582,23 @@ tensor:
582
582
  - atanh
583
583
  - baddbmm
584
584
  - bernoulli
585
+ - bfloat16
585
586
  - bincount
586
587
  - bitwise_and
587
588
  - bitwise_or
588
589
  - bitwise_xor
589
590
  - bmm
590
591
  - bool
592
+ - bool astype
591
593
  - broadcast_to
594
+ - byte
592
595
  - ceil
593
- - cholesky_solve
594
596
  - cholesky
597
+ - cholesky_solve
595
598
  - clamp
596
599
  - clip
597
600
  - conj
601
+ - copy
598
602
  - copysign
599
603
  - cos
600
604
  - cosh
@@ -606,11 +610,13 @@ tensor:
606
610
  - deg2rad
607
611
  - diag
608
612
  - diagflat
613
+ - diagonal
609
614
  - diff
610
615
  - digamma
611
616
  - div
612
617
  - div_
613
618
  - divide
619
+ - double
614
620
  - equal
615
621
  - erf
616
622
  - erfc
@@ -618,13 +624,16 @@ tensor:
618
624
  - exp
619
625
  - expand_as
620
626
  - expm1
627
+ - flatten
621
628
  - flip
622
629
  - fliplr
623
630
  - flipud
631
+ - float
624
632
  - float_power
625
633
  - floor
626
634
  - fmod
627
635
  - frac
636
+ - from_numpy
628
637
  - gather_elements
629
638
  - ge
630
639
  - geqrf
@@ -648,12 +657,12 @@ tensor:
648
657
  - inner
649
658
  - int
650
659
  - inverse
660
+ - is_complex
661
+ - is_signed
651
662
  - isclose
652
663
  - isfinite
653
664
  - isinf
654
665
  - isnan
655
- - is_complex
656
- - is_signed
657
666
  - isneginf
658
667
  - isposinf
659
668
  - isreal
@@ -704,28 +713,27 @@ tensor:
704
713
  - new_ones
705
714
  - new_zeros
706
715
  - nextafter
707
- - norm
708
716
  - nonzero
717
+ - norm
709
718
  - not_equal
710
719
  - ormqr
711
720
  - permute
712
721
  - pow
713
722
  - prod
714
723
  - qr
724
+ - rad2deg
715
725
  - ravel
716
726
  - real
717
727
  - reciprocal
718
728
  - remainder
719
729
  - renorm
720
- - rad2deg
721
- - tile
722
730
  - repeat_interleave
723
731
  - reshape
724
732
  - reshape
725
- - round
733
+ - resize
726
734
  - rot90
735
+ - round
727
736
  - rsqrt
728
- - sum_to_size
729
737
  - scatter
730
738
  - sgn
731
739
  - short
@@ -745,7 +753,8 @@ tensor:
745
753
  - sub
746
754
  - sub_
747
755
  - subtract
748
- - subtract
756
+ - sum
757
+ - sum_to_size
749
758
  - svd
750
759
  - swapaxes
751
760
  - swapdims
@@ -753,13 +762,13 @@ tensor:
753
762
  - take
754
763
  - tan
755
764
  - tanh
756
- - trace
757
- - swapaxes
765
+ - tensor_split
758
766
  - tile
767
+ - to
759
768
  - topk
760
- - tril
761
- - tensor_split
769
+ - trace
762
770
  - transpose
771
+ - tril
763
772
  - true_divide
764
773
  - trunc
765
774
  - unbind
@@ -769,17 +778,6 @@ tensor:
769
778
  - view
770
779
  - where
771
780
  - xlogy
772
- - from_numpy
773
- - std
774
- - take
775
- - var
776
- - all
777
- - any
778
- - copy
779
- - diagonal
780
- - flatten
781
- - resize
782
- - sum
783
781
 
784
782
  mint.ops:
785
783
  - abs
@@ -16,6 +16,7 @@
16
16
  import os
17
17
  from collections import defaultdict
18
18
 
19
+ import mindspore
19
20
  from mindspore._c_expression import PyNativeExecutor_
20
21
  try:
21
22
  from mindspore.common.api import _MindsporeFunctionExecutor
@@ -25,7 +26,10 @@ except ImportError:
25
26
  from msprobe.core.common.log import logger
26
27
  from msprobe.core.common.const import Const
27
28
  from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
28
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
29
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
30
+
31
+
32
+ _api_register = get_api_register()
29
33
 
30
34
 
31
35
  def dump_jit(name, in_feat, out_feat, is_forward):
@@ -57,7 +61,7 @@ def dump_jit(name, in_feat, out_feat, is_forward):
57
61
  class JitDump(_MindsporeFunctionExecutor):
58
62
  dump_config = None
59
63
  jit_enable = False
60
- jit_dump_switch = True
64
+ jit_dump_switch = False
61
65
  jit_count = defaultdict(int)
62
66
 
63
67
  def __init__(self, *args, **kwargs):
@@ -68,8 +72,7 @@ class JitDump(_MindsporeFunctionExecutor):
68
72
  self._executor = PyNativeExecutor_.get_instance()
69
73
 
70
74
  def __call__(self, *args, **kwargs):
71
- if JitDump.jit_dump_switch:
72
- api_register.api_set_ori_func()
75
+ _api_register.restore_all_api()
73
76
  out = super().__call__(*args, **kwargs)
74
77
  if JitDump.jit_dump_switch and len(args) > 0:
75
78
  if self.name and self.name != "construct":
@@ -79,8 +82,7 @@ class JitDump(_MindsporeFunctionExecutor):
79
82
  JitDump.jit_enable = True
80
83
  elif len(args) == 0:
81
84
  logger.warning(f"The jit function {self.name} has no input arguments, nothing will be dumped.")
82
- if JitDump.jit_dump_switch:
83
- api_register.api_set_hook_func()
85
+ _api_register.register_all_api()
84
86
  return out
85
87
 
86
88
  @classmethod
@@ -101,9 +103,12 @@ class JitDump(_MindsporeFunctionExecutor):
101
103
 
102
104
  def grad(self, obj, grad, weights, grad_position, *args, **kwargs):
103
105
  if JitDump.jit_dump_switch and JitDump.jit_enable:
104
- api_register.api_set_ori_func()
105
- output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
106
+ _api_register.restore_all_api()
107
+ if mindspore.__version__ >= "2.5":
108
+ output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values()))
109
+ else:
110
+ output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values()))
106
111
  if JitDump.jit_dump_switch and JitDump.jit_enable:
107
112
  dump_jit(obj, args, None, False)
108
- api_register.api_set_hook_func()
113
+ _api_register.register_all_api()
109
114
  return output
@@ -18,37 +18,10 @@
18
18
  #include <sys/stat.h>
19
19
  #include <cstdlib>
20
20
  #include <cstring>
21
+ #include <pybind11/embed.h>
21
22
  #include "utils/log_adapter.h"
22
23
 
23
- namespace {
24
-
25
- // Utility function to check if a file path is valid
26
- bool IsValidPath(const std::string &path) {
27
- struct stat fileStat;
28
- if (stat(path.c_str(), &fileStat) != 0) {
29
- MS_LOG(ERROR) << "File does not exist or cannot be accessed: " << path;
30
- return false;
31
- }
32
-
33
- if (S_ISLNK(fileStat.st_mode)) {
34
- MS_LOG(ERROR) << "File is a symbolic link, which is not allowed: " << path;
35
- return false;
36
- }
37
-
38
- if (!S_ISREG(fileStat.st_mode)) {
39
- MS_LOG(ERROR) << "File is not a regular file: " << path;
40
- return false;
41
- }
42
-
43
- if (path.substr(path.find_last_of(".")) != ".so") {
44
- MS_LOG(ERROR) << "File is not a .so file: " << path;
45
- return false;
46
- }
47
-
48
- return true;
49
- }
50
-
51
- } // namespace
24
+ namespace py = pybind11;
52
25
 
53
26
  HookDynamicLoader &HookDynamicLoader::GetInstance() {
54
27
  static HookDynamicLoader instance;
@@ -65,38 +38,31 @@ bool HookDynamicLoader::loadFunction(void *handle, const std::string &functionNa
65
38
  return true;
66
39
  }
67
40
 
68
- bool HookDynamicLoader::validateLibraryPath(const std::string &libPath) {
69
- char *realPath = realpath(libPath.c_str(), nullptr);
70
- if (!realPath) {
71
- MS_LOG(WARNING) << "Failed to resolve realpath for the library: " << libPath;
72
- return false;
73
- }
74
-
75
- bool isValid = IsValidPath(realPath);
76
- free(realPath); // Free memory allocated by realpath
77
- return isValid;
78
- }
79
-
80
41
  bool HookDynamicLoader::LoadLibrary() {
81
- const char *libPath = std::getenv("HOOK_TOOL_PATH");
82
- if (!libPath) {
83
- MS_LOG(WARNING) << "HOOK_TOOL_PATH is not set!";
84
- return false;
85
- }
86
-
87
- std::string resolvedLibPath(libPath);
88
- if (!validateLibraryPath(resolvedLibPath)) {
89
- MS_LOG(WARNING) << "Library path validation failed.";
90
- return false;
91
- }
92
-
42
+ std::string msprobePath = "";
43
+ // 获取gil锁
44
+ py::gil_scoped_acquire acquire;
45
+ try {
46
+ py::module msprobeMod = py::module::import("msprobe.lib._msprobe_c");
47
+ if (!py::hasattr(msprobeMod, "__file__")) {
48
+ MS_LOG(WARNING) << "Adump mod not found";
49
+ return false;
50
+ }
51
+ msprobePath = msprobeMod.attr("__file__").cast<std::string>();
52
+ } catch (const std::exception& e) {
53
+ MS_LOG(WARNING) << "Adump mod path unable to get: " << e.what();
54
+ return false;
55
+ }
93
56
  std::lock_guard<std::mutex> lock(mutex_);
94
57
  if (handle_) {
95
58
  MS_LOG(WARNING) << "Hook library already loaded!";
96
59
  return false;
97
60
  }
98
-
99
- handle_ = dlopen(resolvedLibPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
61
+ if (msprobePath == "") {
62
+ MS_LOG(WARNING) << "Adump path not loaded";
63
+ return false;
64
+ }
65
+ handle_ = dlopen(msprobePath.c_str(), RTLD_LAZY | RTLD_LOCAL);
100
66
  if (!handle_) {
101
67
  MS_LOG(WARNING) << "Failed to load Hook library: " << dlerror();
102
68
  return false;
@@ -104,7 +70,7 @@ bool HookDynamicLoader::LoadLibrary() {
104
70
 
105
71
  for (const auto &functionName : functionList_) {
106
72
  if (!loadFunction(handle_, functionName)) {
107
- MS_LOG(WARNING) << "Failed to load function: " << functionName;
73
+ MS_LOG(WARNING) << "Failed to load adump function";
108
74
  dlclose(handle_);
109
75
  handle_ = nullptr;
110
76
  return false;
@@ -40,7 +40,6 @@ class HookDynamicLoader {
40
40
  private:
41
41
  // Helper functions
42
42
  bool loadFunction(void *handle, const std::string &functionName);
43
- bool validateLibraryPath(const std::string &libPath);
44
43
 
45
44
  HookDynamicLoader() = default;
46
45
 
@@ -19,6 +19,7 @@ import os
19
19
  import traceback
20
20
 
21
21
  import mindspore as ms
22
+
22
23
  from msprobe.core.common.const import Const
23
24
  from msprobe.core.common.exceptions import DistributedNotInitializedError
24
25
  from msprobe.core.common.file_utils import check_path_length, load_yaml
@@ -27,7 +28,7 @@ from msprobe.mindspore.common.const import FreeBenchmarkConst
27
28
  from msprobe.mindspore.common.log import logger
28
29
  from msprobe.mindspore.common.utils import get_rank_if_initialized
29
30
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
30
- from msprobe.mindspore.dump.hook_cell.api_registry import api_register
31
+ from msprobe.mindspore.dump.hook_cell.api_register import get_api_register
31
32
  from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
32
33
  from msprobe.mindspore.free_benchmark.common.config import Config
33
34
  from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams
@@ -37,6 +38,9 @@ from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import P
37
38
  from msprobe.mindspore.runtime import Runtime
38
39
 
39
40
 
41
+ _api_register = get_api_register()
42
+
43
+
40
44
  class ApiPyNativeSelfCheck:
41
45
  def __init__(self, config: DebuggerConfig):
42
46
  Config.is_enable = True
@@ -60,8 +64,8 @@ class ApiPyNativeSelfCheck:
60
64
  self.store_original_func()
61
65
 
62
66
  def handle(self):
63
- api_register.initialize_hook(self.build_hook)
64
- api_register.api_set_hook_func()
67
+ _api_register.initialize_hook(self.build_hook)
68
+ _api_register.register_all_api()
65
69
 
66
70
  def build_hook(self, api_name):
67
71
  def pre_hook(cell, input_data):
@@ -166,13 +170,13 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
166
170
  return ret
167
171
 
168
172
  logger.info(f"[{api_name_with_id}] is {Config.handler_type}ing.")
169
- api_register.api_set_ori_func()
173
+ _api_register.restore_all_api()
170
174
 
171
175
  try:
172
176
  perturbation = PerturbationFactory.create(api_name_with_id)
173
177
  params.fuzzed_result = perturbation.handle(params)
174
178
  if params.fuzzed_result is False:
175
- api_register.api_set_hook_func()
179
+ _api_register.register_all_api()
176
180
  return ret
177
181
  if Config.stage == Const.BACKWARD:
178
182
  params.original_result = Tools.get_grad(params.original_func, *params.args, **params.kwargs)
@@ -183,7 +187,7 @@ def check_self(api_name_with_id, output, ori_func, *args, **kwargs):
183
187
  logger.error(f"[{api_name_with_id}] Error: {str(e)}")
184
188
  logger.error(f"[{api_name_with_id}] Error detail: {traceback.format_exc()}")
185
189
 
186
- api_register.api_set_hook_func()
190
+ _api_register.register_all_api()
187
191
  return ret
188
192
 
189
193
 
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import FreeBenchmarkConst
17
+ from msprobe.mindspore.common.log import logger
17
18
  from msprobe.mindspore.free_benchmark.common.config import Config
18
19
  from msprobe.mindspore.free_benchmark.perturbation.add_noise import AddNoisePerturbation
19
20
  from msprobe.mindspore.free_benchmark.perturbation.bit_noise import BitNoisePerturbation
@@ -41,4 +42,5 @@ class PerturbationFactory:
41
42
  if perturbation:
42
43
  return perturbation(api_name_with_id)
43
44
  else:
44
- raise Exception(f'{Config.pert_type} is a invalid perturbation type')
45
+ logger.error(f'{Config.pert_type} is a invalid perturbation type')
46
+ raise ValueError
@@ -1,4 +1,4 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
1
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
2
  # All rights reserved.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +14,7 @@
14
14
  # limitations under the License.
15
15
 
16
16
  from msprobe.mindspore.common.const import Const
17
+ from msprobe.core.common.log import logger
17
18
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
18
19
  from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelfCheck
19
20
 
@@ -41,8 +42,10 @@ class SelfCheckToolFactory:
41
42
  def create(config: DebuggerConfig):
42
43
  tool = SelfCheckToolFactory.tools.get(config.level)
43
44
  if not tool:
44
- raise Exception(f"{config.level} is not supported.")
45
+ logger.error(f"{config.level} is not supported.")
46
+ raise ValueError
45
47
  tool = tool.get(config.execution_mode)
46
48
  if not tool:
47
- raise Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.")
49
+ logger.error(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.")
50
+ raise ValueError
48
51
  return tool(config)
@@ -16,6 +16,7 @@
16
16
  import os
17
17
  import threading
18
18
  from typing import Dict, Union, Tuple
19
+ import time
19
20
 
20
21
  from msprobe.core.common.utils import is_int
21
22
  from msprobe.core.common.file_utils import create_directory, check_path_before_create
@@ -68,6 +69,7 @@ class GlobalContext:
68
69
  create_directory(self._setting.get(GradConst.OUTPUT_PATH))
69
70
  else:
70
71
  logger.warning("The output_path exists, the data will be covered.")
72
+ self._setting[GradConst.TIME_STAMP] = str(int(time.time()))
71
73
 
72
74
  def get_context(self, key: str):
73
75
  if key not in self._setting:
@@ -111,7 +111,8 @@ class CSVGenerator(Process):
111
111
  output_path = context.get_context(GradConst.OUTPUT_PATH)
112
112
  self.level = context.get_context(GradConst.LEVEL)
113
113
  self.bounds = context.get_context(GradConst.BOUNDS)
114
- self.dump_dir = f"{output_path}/rank{rank_id}/Dump/"
114
+ time_stamp = context.get_context(GradConst.TIME_STAMP)
115
+ self.dump_dir = f"{output_path}/rank{rank_id}/Dump{time_stamp}/"
115
116
  self.save_dir = f"{output_path}/rank{rank_id}/"
116
117
  self.current_step = None
117
118
  self.stop_event = multiprocessing.Event()
@@ -49,12 +49,10 @@ class HookInput:
49
49
  self.param_list = grad_context.get_context(GradConst.PARAM_LIST)
50
50
  self.rank_id = get_rank_id()
51
51
  output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
52
- self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump")
52
+ time_stamp = grad_context.get_context(GradConst.TIME_STAMP)
53
+ self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", f"Dump{time_stamp}")
53
54
  self.save_dir = os.path.join(output_path, f"rank{self.rank_id}")
54
55
  self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH)
55
- if os.path.exists(self.save_dir):
56
- logger.warning(f"Delete existing path {self.save_dir}.")
57
- remove_path(self.save_dir)
58
56
  self.level = grad_context.get_context(GradConst.LEVEL)
59
57
  self.bounds = grad_context.get_context(GradConst.BOUNDS)
60
58
  self.mode = mindspore.get_context("mode")
@@ -281,7 +281,7 @@ def create_hooks(context, monitor):
281
281
  global RANK
282
282
  pre_hooks = []
283
283
  hooks = []
284
- RANK = str(get_rank())
284
+ RANK = get_rank()
285
285
  if communication.GlobalComm.INITED and RANK not in monitor.module_rank_list and monitor.module_rank_list != []:
286
286
  return [pre_hooks, hooks]
287
287