mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,4 +1,5 @@
1
- # Copyright 2024 Huawei Technologies Co., Ltd
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
2
3
  #
3
4
  # Licensed under the Apache License, Version 2.0 (the "License");
4
5
  # you may not use this file except in compliance with the License.
@@ -11,37 +12,33 @@
11
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
13
  # See the License for the specific language governing permissions and
13
14
  # limitations under the License.
14
- # ============================================================================
15
15
 
16
- import os
17
16
  import copy
18
17
  import functools
18
+ import os
19
19
  from collections import defaultdict
20
20
 
21
21
  import mindspore as ms
22
- from mindspore.common.tensor import Tensor
23
- from mindspore import ops
24
22
  from mindspore import nn
25
23
  try:
26
24
  from mindspore.common._pijit_context import PIJitCaptureContext
27
- pijit_label = True
28
25
  except ImportError:
29
26
  pijit_label = False
27
+ else:
28
+ pijit_label = True
30
29
 
31
30
 
31
+ from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException
32
+ from msprobe.core.common.file_utils import create_directory
33
+ from msprobe.core.common.utils import Const, print_tools_ends_info
32
34
  from msprobe.core.data_dump.data_collector import build_data_collector
35
+ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs
33
36
  from msprobe.core.data_dump.scope import BaseScope
34
- from msprobe.mindspore.common.utils import get_rank_if_initialized
35
- from msprobe.core.common.file_utils import create_directory
37
+ from msprobe.mindspore.cell_processor import CellProcessor
36
38
  from msprobe.mindspore.common.log import logger
37
- from msprobe.core.common.utils import Const
38
- from msprobe.core.common.exceptions import DistributedNotInitializedError
39
+ from msprobe.mindspore.common.utils import get_rank_if_initialized
39
40
  from msprobe.mindspore.dump.hook_cell.api_registry import api_register
40
- from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
41
- ModuleBackwardInputs, ModuleBackwardOutputs
42
- from msprobe.core.common.exceptions import MsprobeException
43
- from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
44
- from msprobe.mindspore.cell_processor import CellProcessor
41
+ from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
45
42
  from msprobe.mindspore.dump.jit_dump import JitDump
46
43
 
47
44
 
@@ -52,11 +49,12 @@ class Service:
52
49
  self.config.level = self.config.level_ori
53
50
  self.data_collector = build_data_collector(self.config)
54
51
  self.cell_processor = CellProcessor(self.data_collector.scope)
52
+ self.primitive_hook_service = PrimitiveHookService(self)
55
53
  self.switch = False
54
+ self.primitive_switch = False
56
55
  self.current_iter = 0
57
56
  self.first_start = True
58
57
  self.current_rank = None
59
- self.primitive_counters = {}
60
58
  self.dump_iter_dir = None
61
59
  self.start_call = False
62
60
  self.check_level_valid()
@@ -71,28 +69,30 @@ class Service:
71
69
  )
72
70
 
73
71
  def check_level_valid(self):
74
- if self.config.level == "L2":
72
+ if self.config.level == Const.LEVEL_L2:
75
73
  raise MsprobeException(
76
74
  MsprobeException.INVALID_PARAM_ERROR, "L2 level dump function is currently not supported."
77
75
  )
78
76
 
79
77
  def build_hook(self, target_type, name):
80
- def forward_hook(api_or_cell_name, cell, input, output):
78
+ def forward_hook(api_or_cell_name, cell, input_data, output):
81
79
  if not self.should_excute_hook():
80
+ if hasattr(cell, 'input_kwargs'):
81
+ del cell.input_kwargs
82
82
  return None
83
83
 
84
84
  if target_type == BaseScope.Module_Type_Module:
85
- api_or_cell_name = cell.mindstudio_reserved_name
86
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs={}, output=output)
85
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
86
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output)
87
87
  else:
88
- module_input_output = ModuleForwardInputsOutputs(args=input, kwargs=cell.input_kwargs,
88
+ module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs,
89
89
  output=output)
90
90
 
91
91
  self.data_collector.update_api_or_module_name(api_or_cell_name)
92
92
  self.data_collector.forward_data_collect(api_or_cell_name, cell, pid, module_input_output)
93
93
  if self.data_collector.if_return_forward_new_output():
94
94
  return self.data_collector.get_forward_new_output()
95
- if target_type == BaseScope.Module_Type_API:
95
+ if hasattr(cell, 'input_kwargs'):
96
96
  del cell.input_kwargs
97
97
  return output
98
98
 
@@ -100,12 +100,19 @@ class Service:
100
100
  if not self.should_excute_hook():
101
101
  return
102
102
 
103
+ need_exchange = True
103
104
  if target_type == BaseScope.Module_Type_Module:
104
- api_or_cell_name = cell.mindstudio_reserved_name
105
+ if not hasattr(cell, 'has_pre_hook_called') or not cell.has_pre_hook_called:
106
+ need_exchange = False
107
+ api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name)
108
+
105
109
  self.data_collector.update_api_or_module_name(api_or_cell_name)
106
110
  if self.data_collector:
107
111
  # 框架最新接口变更,grad_input和grad_output的含义发生了变化,与torch含义保持一致,因此此处调换顺序传入
108
- module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
112
+ if need_exchange:
113
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input)
114
+ else:
115
+ module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
109
116
  self.data_collector.backward_data_collect(api_or_cell_name, cell, pid, module_input_output)
110
117
 
111
118
  pid = os.getpid()
@@ -114,145 +121,40 @@ class Service:
114
121
  forward_hook = functools.partial(forward_hook, forward_name_template)
115
122
  backward_hook = functools.partial(backward_hook, backward_name_template)
116
123
 
117
- def wrap_forward_hook(cell, input, output):
118
- return forward_hook(cell, input, output)
124
+ def wrap_forward_hook(cell, input_data, output_data):
125
+ return forward_hook(cell, input_data, output_data)
119
126
 
120
127
  def wrap_backward_hook(cell, grad_input, grad_output):
121
128
  return backward_hook(cell, grad_input, grad_output)
122
129
 
123
130
  return wrap_forward_hook, wrap_backward_hook
124
131
 
125
- def wrap_primitive(self, origin_func, primitive_name):
126
- service_instance = self
127
-
128
- def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
129
- def backward_hook(grad):
130
- captured_grads.append(grad)
131
- backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
132
- try:
133
- if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
134
- service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
135
- new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
136
- service_instance.data_collector.backward_output_data_collect(
137
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
138
- )
139
- captured_grads.clear()
140
- elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
141
- service_instance.data_collector.update_api_or_module_name(backward_primitive_name)
142
- new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
143
- service_instance.data_collector.backward_input_data_collect(
144
- backward_primitive_name, service_instance, os.getpid(), new_module_input_output
145
- )
146
- captured_grads.clear()
147
-
148
- except Exception as exception:
149
- raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
150
- f" updated_primitive_name: {updated_primitive_name}") from exception
151
-
152
- return backward_hook
153
-
154
- def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
155
- hooked_inputs = []
156
- num_tensors = sum(isinstance(arg, Tensor) for arg in args)
157
- input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
158
- Const.INPUT)
159
- for _, arg in enumerate(args):
160
- if isinstance(arg, Tensor):
161
- arg_hooked = ops.HookBackward(input_backward_hook)(arg)
162
- hooked_inputs.append(arg_hooked)
163
- else:
164
- hooked_inputs.append(arg)
165
- return hooked_inputs
166
-
167
- def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
168
- if isinstance(out, tuple):
169
- num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
170
- else:
171
- num_output_tensors = 1
172
- output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
173
- updated_primitive_name, Const.OUTPUT)
174
-
175
- if isinstance(out, Tensor):
176
- return ops.HookBackward(output_backward_hook)(out)
177
- elif isinstance(out, tuple):
178
- hooked_outputs = []
179
- for tensor in out:
180
- if isinstance(tensor, Tensor):
181
- hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
182
- else:
183
- hooked_outputs.append(tensor)
184
- return tuple(hooked_outputs)
185
- return out
186
-
187
- def wrapped_primitive_call(instance_self, *args, **kwargs):
188
- service_instance.update_primitive_counters(primitive_name)
189
- current_count = service_instance.primitive_counters.get(primitive_name, 0)
190
- updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
191
-
192
- if not service_instance.switch:
193
- return origin_func(*args, **kwargs)
194
-
195
- captured_grads_input, captured_grads_output = [], []
196
-
197
- try:
198
- hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
199
- except Exception as exception:
200
- raise Exception("This is a primitive op dump error during input hooking: {},"
201
- " primitive_name: {}".format(exception, primitive_name)) from exception
202
-
203
- try:
204
- out = origin_func(*hooked_inputs, **kwargs)
205
- except Exception as exception:
206
- raise Exception("This is a primitive op dump error during function call: {},"
207
- " primitive_name: {}".format(exception, primitive_name)) from exception
208
-
209
- forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
210
- service_instance.data_collector.update_api_or_module_name(forward_primitive_name)
211
- if service_instance.data_collector:
212
- module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
213
- try:
214
- service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
215
- os.getpid(), module_input_output)
216
- except Exception as exception:
217
- raise Exception("This is a primitive op dump error during forward data collection: {},"
218
- " primitive_name: {}".format(exception, primitive_name)) from exception
219
-
220
- if service_instance.data_collector.if_return_forward_new_output():
221
- out = service_instance.data_collector.get_forward_new_output()
222
-
223
- try:
224
- out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
225
- except Exception as exception:
226
- raise Exception("This is a primitive op dump error during output hooking: {},"
227
- " primitive_name: {}".format(exception, primitive_name)) from exception
228
-
229
- return out
230
-
231
- return wrapped_primitive_call
232
-
233
132
  def update_primitive_counters(self, primitive_name):
234
133
  if primitive_name not in self.primitive_counters:
235
134
  self.primitive_counters[primitive_name] = 0
236
135
  else:
237
136
  self.primitive_counters[primitive_name] += 1
238
137
 
239
- def register_hooks(self):
138
+ def register_primitive_hooks(self):
240
139
  primitive_set = set()
241
140
  for _, cell in self.model.cells_and_names():
242
141
  for pname, primitive in cell._primitives.items():
243
142
  primitive_set.add((pname, primitive))
244
143
 
245
144
  for pname, primitive in primitive_set:
246
- NewPrimitive = type('NewPrimitive', (primitive.__class__,),
247
- {'__call__': self.wrap_primitive(primitive.__call__, pname)})
248
- primitive.__class__ = NewPrimitive
145
+ primitive_class_name = primitive.__class__.__name__
146
+ primitive_combined_name = pname + Const.SEP + primitive_class_name
147
+ new_primitive = type('NewPrimitive', (primitive.__class__,),
148
+ {'__call__': self.primitive_hook_service.wrap_primitive(primitive.__call__,
149
+ primitive_combined_name)})
150
+ primitive.__class__ = new_primitive
249
151
 
250
152
  def step(self):
251
153
  self.current_iter += 1
252
154
  self.data_collector.update_iter(self.current_iter)
253
- HOOKCell.cell_count = defaultdict(int)
254
- CellProcessor.cell_count = {}
255
- self.primitive_counters.clear()
155
+ self.primitive_hook_service.primitive_counters.clear()
156
+ self.data_collector.data_writer.reset_cache()
157
+ JitDump.jit_count = defaultdict(int)
256
158
 
257
159
  def start(self, model=None):
258
160
  self.start_call = True
@@ -262,9 +164,8 @@ class Service:
262
164
  api_register.api_set_ori_func()
263
165
  self.should_stop_service = True
264
166
  self.switch = False
265
- logger.info("************************************************")
266
- logger.info(f"* {Const.TOOL_NAME} ends successfully. *")
267
- logger.info("************************************************")
167
+ self.primitive_switch = False
168
+ print_tools_ends_info()
268
169
  return
269
170
  if self.config.step and self.current_iter not in self.config.step:
270
171
  return
@@ -281,7 +182,7 @@ class Service:
281
182
  if self.config.rank and self.current_rank not in self.config.rank:
282
183
  return
283
184
  self.register_hook_new()
284
- if self.config.level == "L1":
185
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
285
186
  JitDump.set_config(self.config)
286
187
  JitDump.set_data_collector(self.data_collector)
287
188
  ms.common.api._MindsporeFunctionExecutor = JitDump
@@ -291,10 +192,32 @@ class Service:
291
192
  PIJitCaptureContext.__exit__ = self.empty
292
193
  self.first_start = False
293
194
 
195
+ api_register.api_set_hook_func()
294
196
  self.switch = True
197
+ self.primitive_switch = True
295
198
  logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
296
199
  self.create_dirs()
297
200
  logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
201
+ JitDump.jit_dump_switch = True
202
+
203
+ def forward_backward_dump_end(self):
204
+ if self.should_stop_service:
205
+ return
206
+ logger.info(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() is set successfully. ")
207
+ if not self.start_call:
208
+ logger.error(f"{Const.TOOL_NAME}: debugger.start() is not set in the current scope.")
209
+ raise Exception("debugger.start() is not set in the current scope.")
210
+ if not self.switch:
211
+ logger.error(f"{Const.TOOL_NAME}: debugger.forward_backward_dump_end() should be called between "
212
+ "debugger.start() and debugger.stop() ")
213
+ raise Exception("debugger.stop() is already called. ")
214
+ if self.config.step and self.current_iter not in self.config.step:
215
+ return
216
+ if self.config.rank and self.current_rank not in self.config.rank:
217
+ return
218
+ self.primitive_switch = False
219
+ api_register.api_set_ori_func()
220
+ JitDump.jit_dump_switch = False
298
221
 
299
222
  def stop(self):
300
223
  if self.should_stop_service:
@@ -309,8 +232,10 @@ class Service:
309
232
  if self.config.rank and self.current_rank not in self.config.rank:
310
233
  return
311
234
  self.switch = False
235
+ self.primitive_switch = False
312
236
  self.start_call = False
313
237
  self.data_collector.write_json()
238
+ JitDump.jit_dump_switch = False
314
239
 
315
240
  def need_end_service(self):
316
241
  if self.config.step and self.current_iter > max(self.config.step):
@@ -349,16 +274,16 @@ class Service:
349
274
 
350
275
  def register_hook_new(self):
351
276
  logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task))
352
- if self.config.level == "L1":
277
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L1]:
353
278
  api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
354
279
  api_register.api_set_hook_func()
355
- if self.model:
356
- self.register_hooks()
280
+ if self.model and self.config.task in Const.DUMP_DATA_COLLECTION_LIST:
281
+ self.register_primitive_hooks()
357
282
 
358
- if self.config.level == "L0":
283
+ if self.config.level in [Const.LEVEL_MIX, Const.LEVEL_L0]:
359
284
  if not self.model:
360
285
  raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
361
- "The current level is L0, the model cannot be None")
286
+ f"The current level is {self.config.level}, the model cannot be None")
362
287
  for name, cell in self.model.cells_and_names():
363
288
  if cell == self.model:
364
289
  continue
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.const import Const
2
17
  from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
3
18
  from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory
msprobe/msprobe.py CHANGED
@@ -45,10 +45,15 @@ def main():
45
45
  multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut')
46
46
  api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare')
47
47
  run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check')
48
+ graph_service_cmd_parser = subparsers.add_parser('graph')
48
49
  _compare_parser(compare_cmd_parser)
49
- is_torch_available=is_module_available("torch")
50
+ is_torch_available = is_module_available("torch")
50
51
  is_mindspore_available = is_module_available("mindspore")
51
- if is_torch_available:
52
+ if len(sys.argv) < 4:
53
+ parser.print_help()
54
+ sys.exit(0)
55
+ framework_args = parser.parse_args(sys.argv[1:3])
56
+ if framework_args.framework == Const.PT_FRAMEWORK:
52
57
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command
53
58
  from msprobe.pytorch.parse_tool.cli import parse as cli_parse
54
59
  from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut
@@ -56,20 +61,24 @@ def main():
56
61
  _api_precision_compare_command
57
62
  from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \
58
63
  _run_overflow_check_command
64
+ from msprobe.visualization.graph_service import _pt_graph_service_parser, _pt_graph_service_command
59
65
 
60
66
  _run_ut_parser(run_ut_cmd_parser)
61
67
  _run_ut_parser(multi_run_ut_cmd_parser)
62
68
  multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
63
- help='Number of splits for parallel processing. Range: 1-64')
69
+ help='Number of splits for parallel processing. Range: 1-64')
64
70
  _api_precision_compare_parser(api_precision_compare_cmd_parser)
65
71
  _run_overflow_check_parser(run_overflow_check_cmd_parser)
66
- elif is_mindspore_available:
72
+ _pt_graph_service_parser(graph_service_cmd_parser)
73
+ elif framework_args.framework == Const.MS_FRAMEWORK:
67
74
  from msprobe.mindspore.api_accuracy_checker.cmd_parser import add_api_accuracy_checker_argument
75
+ from msprobe.visualization.graph_service import _ms_graph_service_parser, _ms_graph_service_command
68
76
  add_api_accuracy_checker_argument(run_ut_cmd_parser)
77
+ from msprobe.mindspore.api_accuracy_checker.cmd_parser import multi_add_api_accuracy_checker_argument
78
+ multi_add_api_accuracy_checker_argument(multi_run_ut_cmd_parser)
79
+
80
+ _ms_graph_service_parser(graph_service_cmd_parser)
69
81
 
70
- if len(sys.argv) == 1:
71
- parser.print_help()
72
- sys.exit(0)
73
82
  args = parser.parse_args(sys.argv[1:])
74
83
  if sys.argv[2] == Const.PT_FRAMEWORK:
75
84
  if not is_torch_available:
@@ -86,6 +95,8 @@ def main():
86
95
  _api_precision_compare_command(args)
87
96
  elif sys.argv[3] == "run_overflow_check":
88
97
  _run_overflow_check_command(args)
98
+ elif sys.argv[3] == "graph":
99
+ _pt_graph_service_command(args)
89
100
  elif sys.argv[3] == "compare":
90
101
  if args.cell_mapping is not None or args.api_mapping is not None:
91
102
  logger.error("Argument -cm or -am is not supported in PyTorch framework")
@@ -100,6 +111,12 @@ def main():
100
111
  elif sys.argv[3] == "run_ut":
101
112
  from msprobe.mindspore.api_accuracy_checker.main import api_checker_main
102
113
  api_checker_main(args)
114
+ elif sys.argv[3] == "multi_run_ut":
115
+ from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main
116
+ mul_api_checker_main(args)
117
+ elif sys.argv[3] == "graph":
118
+ _ms_graph_service_command(args)
119
+
103
120
 
104
121
  if __name__ == "__main__":
105
122
  main()
@@ -1,4 +1,24 @@
1
- from .debugger.precision_debugger import PrecisionDebugger
2
- from .common.utils import seed_all
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ from msprobe.pytorch.monitor.module_hook import TrainerMon
3
20
  from .compare.distributed_compare import compare_distributed
4
- from .compare.pt_compare import compare
21
+ from .compare.pt_compare import compare
22
+ from .common.utils import seed_all
23
+ from .debugger.precision_debugger import PrecisionDebugger
24
+ from .functional.module_dump import module_dump, module_dump_end
@@ -1,8 +1,33 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
4
+ # All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
1
18
  import os
19
+ from collections import namedtuple
2
20
  from msprobe.core.common.file_utils import load_yaml, check_file_or_directory_path
21
+ from msprobe.core.common.utils import is_int
3
22
  from msprobe.pytorch.pt_config import RunUTConfig
4
23
 
5
24
 
25
+ RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
26
+ 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
27
+ 'black_list', 'error_data_path', 'online_config'])
28
+ OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
29
+
30
+
6
31
  class Config:
7
32
  def __init__(self, yaml_file):
8
33
  check_file_or_directory_path(yaml_file, False)
@@ -33,8 +58,10 @@ class Config:
33
58
  raise ValueError(f"{key} must be one of {validators.keys()}")
34
59
  if not isinstance(value, validators.get(key)):
35
60
  raise ValueError(f"{key} must be {validators[key].__name__} type")
36
- if key == 'precision' and value < 0:
37
- raise ValueError("precision must be greater than 0")
61
+ if key == 'precision' and not is_int(value):
62
+ raise ValueError("precision must be an integer")
63
+ if key == 'precision' and (value < 0 or value > 20):
64
+ raise ValueError("precision must be greater than or equal to 0 and less than 21")
38
65
  if key == 'white_list':
39
66
  RunUTConfig.check_filter_list_config(key, value)
40
67
  if key == 'black_list':
@@ -51,3 +78,55 @@ class Config:
51
78
  cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
52
79
  yaml_path = os.path.join(cur_path, "config.yaml")
53
80
  msCheckerConfig = Config(yaml_path)
81
+
82
+
83
+ class CheckerConfig:
84
+ def __init__(self, task_config=None):
85
+ self.white_list = msCheckerConfig.white_list
86
+ self.black_list = msCheckerConfig.black_list
87
+ self.error_data_path = msCheckerConfig.error_data_path
88
+ self.is_online = msCheckerConfig.is_online
89
+ self.nfs_path = msCheckerConfig.nfs_path
90
+ self.host = msCheckerConfig.host
91
+ self.port = msCheckerConfig.port
92
+ self.rank_list = msCheckerConfig.rank_list
93
+ self.tls_path = msCheckerConfig.tls_path
94
+
95
+ if task_config:
96
+ self.load_config(task_config)
97
+
98
+ def load_config(self, task_config):
99
+ self.white_list = task_config.white_list
100
+ self.black_list = task_config.black_list
101
+ self.error_data_path = task_config.error_data_path
102
+ self.is_online = task_config.is_online
103
+ self.nfs_path = task_config.nfs_path
104
+ self.host = task_config.host
105
+ self.port = task_config.port
106
+ self.rank_list = task_config.rank_list
107
+ self.tls_path = task_config.tls_path
108
+
109
+ def get_online_config(self):
110
+ return OnlineConfig(
111
+ is_online=self.is_online,
112
+ nfs_path=self.nfs_path,
113
+ host=self.host,
114
+ port=self.port,
115
+ rank_list=self.rank_list,
116
+ tls_path=self.tls_path
117
+ )
118
+
119
+ def get_run_ut_config(self, **config_params):
120
+ return RunUtConfig(
121
+ forward_content=config_params.get('forward_content'),
122
+ backward_content=config_params.get('backward_content'),
123
+ result_csv_path=config_params.get('result_csv_path'),
124
+ details_csv_path=config_params.get('details_csv_path'),
125
+ save_error_data=config_params.get('save_error_data'),
126
+ is_continue_run_ut=config_params.get('is_continue_run_ut'),
127
+ real_data_path=config_params.get('real_data_path'),
128
+ white_list=self.white_list,
129
+ black_list=self.black_list,
130
+ error_data_path=config_params.get('error_data_path'),
131
+ online_config=self.get_online_config()
132
+ )