mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220) hide show
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import stat
3
18
 
@@ -10,6 +25,7 @@ class Const:
10
25
  """
11
26
  TOOL_NAME = "msprobe"
12
27
 
28
+ ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$"
13
29
  SEP = "."
14
30
  REGEX_PREFIX_MAX_LENGTH = 20
15
31
  REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
@@ -20,6 +36,8 @@ class Const:
20
36
  OFF = 'OFF'
21
37
  BACKWARD = 'backward'
22
38
  FORWARD = 'forward'
39
+ PROGRESS_TIMEOUT = 3000
40
+ EXCEPTION_NONE = None
23
41
  JIT = 'Jit'
24
42
  PRIMITIVE_PREFIX = 'Primitive'
25
43
  DEFAULT_LIST = []
@@ -82,6 +100,7 @@ class Const:
82
100
  GRAD_PROBE = "grad_probe"
83
101
  TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
84
102
  DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
103
+ DUMP_DATA_MODE_LIST = [ALL, INPUT, OUTPUT, FORWARD, BACKWARD]
85
104
  LEVEL_L0 = "L0"
86
105
  LEVEL_L1 = "L1"
87
106
  LEVEL_L2 = "L2"
@@ -93,6 +112,7 @@ class Const:
93
112
  DATA = "data"
94
113
  PT_FRAMEWORK = "pytorch"
95
114
  MS_FRAMEWORK = "mindspore"
115
+ UNKNOWN_FRAMEWORK = "unknown"
96
116
  DIRECTORY_LENGTH = 4096
97
117
  FILE_NAME_LENGTH = 255
98
118
  FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
@@ -103,6 +123,8 @@ class Const:
103
123
  CPU_LOWERCASE = 'cpu'
104
124
  CUDA_LOWERCASE = 'cuda'
105
125
  DISTRIBUTED = 'Distributed'
126
+ DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive",
127
+ "Aten", "VF", "NPU", "Jit"]
106
128
 
107
129
  # struct json param
108
130
  ORIGIN_DATA = "origin_data"
@@ -113,21 +135,25 @@ class Const:
113
135
  MODULE_WHITE_LIST = ["torch", "numpy"]
114
136
 
115
137
  FUNC_SKIP_LIST = ["construct", "__call__"]
116
-
117
- FILE_SKIP_LIST = ["site-packages/mindspore", "package/mindspore", "msprobe", "site-packages/torch", "package/torch"]
138
+ FILE_SKIP_LIST = ["msprobe", "MindSpeed"]
139
+ DATA_TYPE_SKIP_LIST = ["Primitive", "Jit"]
118
140
 
119
141
  STACK_FILE_INDEX = 0
120
-
121
142
  STACK_FUNC_INDEX = 2
122
-
123
143
  STACK_FUNC_ELE_INDEX = 1
124
144
 
125
- CONSTRUCT_NAME_INDEX = -3
126
-
127
- NAME_FIRST_POSSIBLE_INDEX = -4
128
-
129
- NAME_SECOND_POSSIBLE_INDEX = -5
130
-
145
+ SCOPE_ID_INDEX = -1
146
+ SCOPE_DIRECTION_INDEX = -2
147
+ TYPE_NAME_INDEX = -3
148
+ LAYER_NAME_INDEX = -4
149
+ API_TYPE_INDEX = 0
150
+ LEFT_MOVE_INDEX = -1
151
+ RIGHT_MOVE_INDEX = 1
152
+
153
+ TOP_LAYER = "TopLayer"
154
+ CELL = "Cell"
155
+ MODULE = "Module"
156
+ FRAME_FILE_LIST = ["site-packages/torch", "package/torch", "site-packages/mindspore", "package/mindspore"]
131
157
  INPLACE_LIST = [
132
158
  "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
133
159
  "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
@@ -145,11 +171,12 @@ class Const:
145
171
  FILL_CHAR_NUMS = 50
146
172
  TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
147
173
  WITHOUT_CALL_STACK = "The call stack retrieval failed."
148
-
174
+
149
175
  STEP = "step"
150
176
  RANK = "rank"
151
177
  HYPHEN = "-"
152
- STEP_RANK_MAXIMUM_RANGE = [int(0), int(1e6)]
178
+ STEP_RANK_MINIMUM_VALUE = 0
179
+ STEP_RANK_MAXIMUM_VALUE = int(1e6)
153
180
 
154
181
  # data type const
155
182
  FLOAT16 = "Float16"
@@ -159,6 +186,13 @@ class Const:
159
186
  TORCH_FLOAT32 = "torch.float32"
160
187
  TORCH_BFLOAT16 = "torch.bfloat16"
161
188
 
189
+ DTYPE = 'dtype'
190
+ SHAPE = 'shape'
191
+ MAX = 'Max'
192
+ MIN = 'Min'
193
+ MEAN = 'Mean'
194
+ NORM = 'Norm'
195
+
162
196
 
163
197
  class CompareConst:
164
198
  """
@@ -201,10 +235,17 @@ class CompareConst:
201
235
  RESULT = "Result"
202
236
  MAGNITUDE = 0.5
203
237
  OP_NAME = "op_name"
238
+ STRUCT = "struct"
204
239
  INPUT_STRUCT = "input_struct"
240
+ KWARGS_STRUCT = "kwargs_struct"
205
241
  OUTPUT_STRUCT = "output_struct"
206
242
  SUMMARY = "summary"
207
243
  MAX_EXCEL_LENGTH = 1048576
244
+ YES = "Yes"
245
+ NO = "No"
246
+ STATISTICS_INDICATOR_NUM = 4
247
+ EPSILON = 1e-10
248
+ COMPARE_ENDS_SUCCESSFULLY = "msprobe compare ends successfully."
208
249
 
209
250
  COMPARE_RESULT_HEADER = [
210
251
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
@@ -222,6 +263,12 @@ class CompareConst:
222
263
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
223
264
  ]
224
265
 
266
+ HEAD_OF_COMPARE_MODE = {
267
+ Const.ALL: COMPARE_RESULT_HEADER,
268
+ Const.SUMMARY: SUMMARY_COMPARE_RESULT_HEADER,
269
+ Const.MD5: MD5_COMPARE_RESULT_HEADER
270
+ }
271
+
225
272
  # compare standard
226
273
  HUNDRED_RATIO_THRESHOLD = 0.01
227
274
  THOUSAND_RATIO_THRESHOLD = 0.001
@@ -241,6 +288,8 @@ class CompareConst:
241
288
  PASS = 'pass'
242
289
  WARNING = 'Warning'
243
290
  ERROR = 'error'
291
+ TRUE = 'TRUE'
292
+ FALSE = 'FALSE'
244
293
  SKIP = 'SKIP'
245
294
  N_A = 'N/A'
246
295
  INF = 'inf'
@@ -298,6 +347,13 @@ class CompareConst:
298
347
  MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None,
299
348
  MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None
300
349
  }
350
+ INPUT_PATTERN = Const.SEP + Const.INPUT + Const.SEP
351
+ KWARGS_PATTERN = Const.SEP + Const.KWARGS + Const.SEP
352
+ OUTPUT_PATTERN = Const.SEP + Const.OUTPUT + Const.SEP
353
+ COMPARE_KEY = 'compare_key'
354
+ COMPARE_SHAPE = 'compare_shape'
355
+ INTERNAL_API_MAPPING_FILE = 'ms_to_pt_api.yaml'
356
+ UNREADABLE = 'unreadable data'
301
357
 
302
358
 
303
359
  class FileCheckConst:
@@ -322,7 +378,7 @@ class FileCheckConst:
322
378
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
323
379
  MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
324
380
  MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
325
- MAX_YAML_SIZE = 1048576 # 1 * 1024 * 1024
381
+ MAX_YAML_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
326
382
  COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
327
383
  DIR = "dir"
328
384
  FILE = "file"
@@ -351,6 +407,9 @@ class MsCompareConst:
351
407
  # api_info field
352
408
  MINT = "Mint"
353
409
  MINT_FUNCTIONAL = "MintFunctional"
410
+ TENSOR_API = "Tensor"
411
+
412
+ API_NAME_STR_LENGTH = 4
354
413
 
355
414
  TASK_FIELD = "task"
356
415
  STATISTICS_TASK = "statistics"
@@ -358,6 +417,10 @@ class MsCompareConst:
358
417
  DUMP_DATA_DIR_FIELD = "dump_data_dir"
359
418
  DATA_FIELD = "data"
360
419
 
420
+ # supported api yaml
421
+ SUPPORTED_API_LIST_FILE = "checker_support_api.yaml"
422
+ SUPPORTED_TENSOR_LIST_KEY = "tensor"
423
+
361
424
  # detail_csv
362
425
  DETAIL_CSV_API_NAME = "API Name"
363
426
  DETAIL_CSV_BENCH_DTYPE = "Bench Dtype"
@@ -382,15 +445,20 @@ class MsgConst:
382
445
  MSPROBE_LOG_LEVEL = "MSPROBE_LOG_LEVEL"
383
446
  LOG_LEVEL_ENUM = ["0", "1", "2", "3", "4"]
384
447
  LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
448
+
385
449
  class LogLevel:
386
450
  class DEBUG:
387
451
  value = 0
452
+
388
453
  class INFO:
389
454
  value = 1
455
+
390
456
  class WARNING:
391
457
  value = 2
458
+
392
459
  class ERROR:
393
460
  value = 3
461
+
394
462
  SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
395
463
 
396
464
  NOT_CREATED_INSTANCE = "PrecisionDebugger instance is not created."
@@ -400,3 +468,35 @@ class GraphMode:
400
468
  NPY_MODE = "NPY_MODE"
401
469
  STATISTIC_MODE = "STATISTIC_MODE"
402
470
  ERROR_MODE = "ERROR_MODE"
471
+
472
+
473
+ class MonitorConst:
474
+ """
475
+ Class for monitor const
476
+ """
477
+ OP_LIST = ["min", "max", "norm", "zeros", "nans", "id", "mean"]
478
+ MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR"
479
+ DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output"
480
+ DATABASE = "database"
481
+ EMAIL = "email"
482
+ OPT_TY = ['Megatron_DistributedOptimizer', 'Megatron_Float16OptimizerWithFloat16Params']
483
+ DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage0", "DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3")
484
+ RULE_NAME = ['AnomalyTurbulence']
485
+
486
+ DOT = "."
487
+ VPP_SEP = ":"
488
+ ACTV_IN = "input"
489
+ ACTV_OUT = "output"
490
+ ACTVGRAD_IN = "input_grad"
491
+ ACTVGRAD_OUT = "output_grad"
492
+ POST_GRAD = "post_grad"
493
+ PRE_GRAD = "pre_grad"
494
+ PREFIX_POST = "post"
495
+ PREFIX_PRE = "pre"
496
+
497
+
498
+ ANOMALY_JSON = "anomaly.json"
499
+ ANALYSE_JSON = "anomaly_analyse.json"
500
+ TENSORBOARD = "tensorboard"
501
+ CSV = "csv"
502
+ API = "api"
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  class CodedException(Exception):
2
17
  def __init__(self, code, error_info=''):
3
18
  super().__init__()
@@ -11,10 +26,12 @@ class CodedException(Exception):
11
26
  class MsprobeException(CodedException):
12
27
  INVALID_PARAM_ERROR = 0
13
28
  OVERFLOW_NUMS_ERROR = 1
29
+ RECURSION_LIMIT_ERROR = 2
14
30
 
15
31
  err_strs = {
16
32
  INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
17
- OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
33
+ OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:",
34
+ RECURSION_LIMIT_ERROR: "[msprobe] 递归调用超过限制:"
18
35
  }
19
36
 
20
37
 
@@ -41,7 +58,7 @@ class ParseJsonException(CodedException):
41
58
  InvalidDumpJson = 1
42
59
  err_strs = {
43
60
  UnexpectedNameStruct: "[msprobe] Unexpected name in json: ",
44
- InvalidDumpJson: "[msprobe] json格式不正确: ",
61
+ InvalidDumpJson: "[msprobe] Invalid dump.json format: ",
45
62
  }
46
63
 
47
64
 
@@ -73,9 +90,13 @@ class StepException(CodedException):
73
90
  class FreeBenchmarkException(CodedException):
74
91
  UnsupportedType = 0
75
92
  InvalidGrad = 1
93
+ InvalidPerturbedOutput = 2
94
+ OutputIndexError = 3
76
95
  err_strs = {
77
96
  UnsupportedType: "[msprobe] Free benchmark get unsupported type: ",
78
97
  InvalidGrad: "[msprobe] Free benchmark gradient invalid: ",
98
+ InvalidPerturbedOutput: "[msprobe] Free benchmark invalid perturbed output: ",
99
+ OutputIndexError: "[msprobe] Free benchmark output index out of bounds: ",
79
100
  }
80
101
 
81
102
 
@@ -87,6 +108,7 @@ class DistributedNotInitializedError(Exception):
87
108
  def __str__(self):
88
109
  return self.msg
89
110
 
111
+
90
112
  class ApiAccuracyCheckerException(CodedException):
91
113
  ParseJsonFailed = 0
92
114
  UnsupportType = 1
@@ -97,4 +119,4 @@ class ApiAccuracyCheckerException(CodedException):
97
119
  UnsupportType: "[msprobe] Api Accuracy Checker get unsupported type: ",
98
120
  WrongValue: "[msprobe] Api Accuracy Checker get wrong value: ",
99
121
  ApiWrong: "[msprobe] Api Accuracy Checker something wrong with api: ",
100
- }
122
+ }
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,13 +12,17 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import csv
18
17
  import fcntl
19
18
  import os
19
+ import stat
20
20
  import json
21
21
  import re
22
22
  import shutil
23
+ from datetime import datetime, timezone
24
+ from dateutil import parser
25
+ import OpenSSL
23
26
  import yaml
24
27
  import numpy as np
25
28
  import pandas as pd
@@ -67,9 +70,11 @@ class FileChecker:
67
70
  self.check_path_ability()
68
71
  if self.is_script:
69
72
  check_path_owner_consistent(self.file_path)
70
- check_path_pattern_vaild(self.file_path)
73
+ check_path_pattern_valid(self.file_path)
71
74
  check_common_file_size(self.file_path)
72
75
  check_file_suffix(self.file_path, self.file_type)
76
+ if self.path_type == FileCheckConst.FILE:
77
+ check_dirpath_before_read(self.file_path)
73
78
  return self.file_path
74
79
 
75
80
  def check_path_ability(self):
@@ -122,9 +127,10 @@ class FileOpen:
122
127
  self.file_path = os.path.realpath(self.file_path)
123
128
  check_path_length(self.file_path)
124
129
  self.check_ability_and_owner()
125
- check_path_pattern_vaild(self.file_path)
130
+ check_path_pattern_valid(self.file_path)
126
131
  if os.path.exists(self.file_path):
127
132
  check_common_file_size(self.file_path)
133
+ check_dirpath_before_read(self.file_path)
128
134
 
129
135
  def check_ability_and_owner(self):
130
136
  if self.mode in self.SUPPORT_READ_MODE:
@@ -193,7 +199,7 @@ def check_path_owner_consistent(path):
193
199
  raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
194
200
 
195
201
 
196
- def check_path_pattern_vaild(path):
202
+ def check_path_pattern_valid(path):
197
203
  if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
198
204
  logger.error('The file path %s contains special characters.' % (path))
199
205
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
@@ -217,7 +223,6 @@ def check_common_file_size(file_path):
217
223
  check_file_size(file_path, max_size)
218
224
  return
219
225
  check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
220
-
221
226
 
222
227
 
223
228
  def check_file_suffix(file_path, file_suffix):
@@ -238,9 +243,18 @@ def check_path_type(file_path, file_type):
238
243
  raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
239
244
 
240
245
 
246
+ def check_others_writable(directory):
247
+ dir_stat = os.stat(directory)
248
+ is_writable = (
249
+ bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写
250
+ bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写
251
+ )
252
+ return is_writable
253
+
254
+
241
255
  def make_dir(dir_path):
242
- dir_path = os.path.realpath(dir_path)
243
256
  check_path_before_create(dir_path)
257
+ dir_path = os.path.realpath(dir_path)
244
258
  if os.path.isdir(dir_path):
245
259
  return
246
260
  try:
@@ -262,8 +276,9 @@ def create_directory(dir_path):
262
276
  Exception Description:
263
277
  when invalid data throw exception
264
278
  """
265
- dir_path = os.path.realpath(dir_path)
279
+ check_link(dir_path)
266
280
  check_path_before_create(dir_path)
281
+ dir_path = os.path.realpath(dir_path)
267
282
  parent_dir = os.path.dirname(dir_path)
268
283
  if not os.path.isdir(parent_dir):
269
284
  create_directory(parent_dir)
@@ -271,6 +286,7 @@ def create_directory(dir_path):
271
286
 
272
287
 
273
288
  def check_path_before_create(path):
289
+ check_link(path)
274
290
  if path_len_exceeds_limit(path):
275
291
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
276
292
 
@@ -279,6 +295,17 @@ def check_path_before_create(path):
279
295
  'The file path {} contains special characters.'.format(path))
280
296
 
281
297
 
298
+ def check_dirpath_before_read(path):
299
+ path = os.path.realpath(path)
300
+ dirpath = os.path.dirname(path)
301
+ if check_others_writable(dirpath):
302
+ logger.warning(f"The directory is writable by others: {dirpath}.")
303
+ try:
304
+ check_path_owner_consistent(dirpath)
305
+ except FileCheckException:
306
+ logger.warning(f"The directory {dirpath} is not yours.")
307
+
308
+
282
309
  def check_file_or_directory_path(path, isdir=False):
283
310
  """
284
311
  Function Description:
@@ -344,7 +371,7 @@ def load_yaml(yaml_path):
344
371
  def load_npy(filepath):
345
372
  check_file_or_directory_path(filepath)
346
373
  try:
347
- npy = np.load(filepath)
374
+ npy = np.load(filepath, allow_pickle=False)
348
375
  except Exception as e:
349
376
  logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
350
377
  raise RuntimeError(f"Load numpy file {filepath} failed.") from e
@@ -354,7 +381,7 @@ def load_npy(filepath):
354
381
  def load_json(json_path):
355
382
  try:
356
383
  with FileOpen(json_path, "r") as f:
357
- fcntl.flock(f, fcntl.LOCK_EX)
384
+ fcntl.flock(f, fcntl.LOCK_SH)
358
385
  data = json.load(f)
359
386
  fcntl.flock(f, fcntl.LOCK_UN)
360
387
  except Exception as e:
@@ -363,11 +390,11 @@ def load_json(json_path):
363
390
  return data
364
391
 
365
392
 
366
- def save_json(json_path, data, indent=None):
367
- json_path = os.path.realpath(json_path)
393
+ def save_json(json_path, data, indent=None, mode="w"):
368
394
  check_path_before_create(json_path)
395
+ json_path = os.path.realpath(json_path)
369
396
  try:
370
- with FileOpen(json_path, 'w') as f:
397
+ with FileOpen(json_path, mode) as f:
371
398
  fcntl.flock(f, fcntl.LOCK_EX)
372
399
  json.dump(data, f, indent=indent)
373
400
  fcntl.flock(f, fcntl.LOCK_UN)
@@ -378,8 +405,8 @@ def save_json(json_path, data, indent=None):
378
405
 
379
406
 
380
407
  def save_yaml(yaml_path, data):
381
- yaml_path = os.path.realpath(yaml_path)
382
408
  check_path_before_create(yaml_path)
409
+ yaml_path = os.path.realpath(yaml_path)
383
410
  try:
384
411
  with FileOpen(yaml_path, 'w') as f:
385
412
  fcntl.flock(f, fcntl.LOCK_EX)
@@ -391,6 +418,21 @@ def save_yaml(yaml_path, data):
391
418
  change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
392
419
 
393
420
 
421
def save_excel(path, data):
    """
    Save a pandas DataFrame as an Excel file.

    The path is validated with check_path_before_create and then resolved
    with os.path.realpath. After a successful write the file mode is set to
    FileCheckConst.DATA_FILE_AUTHORITY through change_mode.

    Parameters:
        path (str): Destination path of the Excel file.
        data (pd.DataFrame): Table to write; the index is not written.

    Raises:
        RuntimeError: If writing the Excel file fails.
    """
    check_path_before_create(path)
    path = os.path.realpath(path)

    # Anything other than a DataFrame is reported and skipped: no file is
    # written and no permissions are changed.
    if not isinstance(data, pd.DataFrame):
        logger.error('unsupported data type.')
        return

    try:
        data.to_excel(path, index=False)
    except Exception as e:
        logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
        raise RuntimeError(f"Save excel file {path} failed.") from e
    change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
434
+
435
+
394
436
  def move_file(src_path, dst_path):
395
437
  check_file_or_directory_path(src_path)
396
438
  check_path_before_create(dst_path)
@@ -403,8 +445,8 @@ def move_file(src_path, dst_path):
403
445
 
404
446
 
405
447
  def save_npy(data, filepath):
406
- filepath = os.path.realpath(filepath)
407
448
  check_path_before_create(filepath)
449
+ filepath = os.path.realpath(filepath)
408
450
  try:
409
451
  np.save(filepath, data)
410
452
  except Exception as e:
@@ -425,6 +467,7 @@ def save_npy_to_txt(data, dst_file='', align=0):
425
467
  pad_array = np.zeros((align - data.size % align,))
426
468
  data = np.append(data, pad_array)
427
469
  check_path_before_create(dst_file)
470
+ dst_file = os.path.realpath(dst_file)
428
471
  try:
429
472
  np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
430
473
  except Exception as e:
@@ -438,8 +481,8 @@ def save_workbook(workbook, file_path):
438
481
  workbook: 要保存的工作簿对象
439
482
  file_path: 文件保存路径
440
483
  """
441
- file_path = os.path.realpath(file_path)
442
484
  check_path_before_create(file_path)
485
+ file_path = os.path.realpath(file_path)
443
486
  try:
444
487
  workbook.save(file_path)
445
488
  except Exception as e:
@@ -451,7 +494,7 @@ def save_workbook(workbook, file_path):
451
494
  def write_csv(data, filepath, mode="a+", malicious_check=False):
452
495
  def csv_value_is_valid(value: str) -> bool:
453
496
  if not isinstance(value, str):
454
- return True
497
+ return True
455
498
  try:
456
499
  # -1.00 or +1.00 should be consdiered as digit numbers
457
500
  float(value)
@@ -459,16 +502,16 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
459
502
  # otherwise, they will be considered as formular injections
460
503
  return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
461
504
  return True
462
-
505
+
463
506
  if malicious_check:
464
507
  for row in data:
465
508
  for cell in row:
466
509
  if not csv_value_is_valid(cell):
467
- raise RuntimeError(f"Malicious value [{cell}] is not allowed " \
510
+ raise RuntimeError(f"Malicious value [{cell}] is not allowed "
468
511
  f"to be written into the csv: {filepath}.")
469
512
 
470
- file_path = os.path.realpath(filepath)
471
513
  check_path_before_create(filepath)
514
+ file_path = os.path.realpath(filepath)
472
515
  try:
473
516
  with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
474
517
  writer = csv.writer(f)
@@ -479,16 +522,54 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
479
522
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
480
523
 
481
524
 
482
- def read_csv(filepath):
525
+ def read_csv(filepath, as_pd=True):
483
526
  check_file_or_directory_path(filepath)
484
527
  try:
485
- csv_data = pd.read_csv(filepath)
528
+ if as_pd:
529
+ csv_data = pd.read_csv(filepath)
530
+ else:
531
+ with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
532
+ csv_reader = csv.reader(f, delimiter=',')
533
+ csv_data = list(csv_reader)
486
534
  except Exception as e:
487
535
  logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
488
536
  raise RuntimeError(f"Read csv file {filepath} failed.") from e
489
537
  return csv_data
490
538
 
491
539
 
540
def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False):
    """
    Write a pandas DataFrame to a csv file.

    The path is validated with check_path_before_create and resolved with
    os.path.realpath; the resolved path is used for the write and for the
    final change_mode call, which applies FileCheckConst.DATA_FILE_AUTHORITY.

    Parameters:
        data (pd.DataFrame): Table to write; the index is not written.
        filepath (str): Destination path of the csv file.
        mode (str): File mode forwarded to DataFrame.to_csv.
        header (bool): Forwarded to DataFrame.to_csv.
        malicious_check (bool): When True, every cell is checked against
            FileCheckConst.CSV_BLACK_LIST before anything is written.

    Raises:
        ValueError: If data is not a pd.DataFrame.
        RuntimeError: If a cell fails the malicious-value check, or if
            writing the csv file fails.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("The data type of data is not supported. Only support pd.DataFrame.")

    if malicious_check:
        # Compile once instead of once per cell.
        black_list = re.compile(FileCheckConst.CSV_BLACK_LIST)

        def csv_value_is_valid(value) -> bool:
            if not isinstance(value, str):
                return True
            try:
                # -1.00 or +1.00 should be considered as digit numbers
                float(value)
            except ValueError:
                # otherwise, they will be considered as formula injections
                return not bool(black_list.search(value))
            return True

        # Plain tuples in row-major order: the first offending cell reported
        # is the same one the index-based scan would have found.
        for row in data.itertuples(index=False, name=None):
            for cell in row:
                if not csv_value_is_valid(cell):
                    raise RuntimeError(f"Malicious value [{cell}] is not allowed "
                                       f"to be written into the csv: {filepath}.")

    check_path_before_create(filepath)
    file_path = os.path.realpath(filepath)
    try:
        data.to_csv(file_path, mode=mode, header=header, index=False)
    except Exception as e:
        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
        raise RuntimeError(f"Save csv file {file_path} failed.") from e
    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
571
+
572
+
492
573
  def remove_path(path):
493
574
  if not os.path.exists(path):
494
575
  return
@@ -521,3 +602,46 @@ def get_json_contents(file_path):
521
602
def get_file_content_bytes(file):
    """
    Read a file through FileOpen in 'rb' mode and return what read() yields.

    Parameters:
        file (str): Path of the file to read.
    """
    with FileOpen(file, 'rb') as file_handle:
        content = file_handle.read()
    return content
605
+
606
+
607
# os.walk with a limit on the traversal depth
def os_walk_for_files(path, depth):
    """
    Collect the files found under ``path`` down to a limited depth.

    Every visited directory is validated with
    check_file_or_directory_path(root, isdir=True). A directory whose level
    below ``path`` is >= depth is pruned: its files are not collected and
    its subdirectories are not visited. ``path`` itself is level 0.

    Parameters:
        path (str): Directory at which the walk starts.
        depth (int): Number of directory levels to collect files from.

    Returns:
        list[dict]: One {"file": <file name>, "root": <directory>} entry per
        collected file.
    """
    res = []
    # Strip trailing separators before counting, so "dir/" and "/" give the
    # same level numbering as "dir"; otherwise one extra level is walked.
    base_level = path.rstrip(os.sep).count(os.sep)
    for root, dirs, files in os.walk(path, topdown=True):
        check_file_or_directory_path(root, isdir=True)
        if root.rstrip(os.sep).count(os.sep) - base_level >= depth:
            # With topdown=True, emptying dirs in place stops the descent.
            dirs[:] = []
        else:
            res.extend({"file": file, "root": root} for file in files)
    return res
618
+
619
+
620
def check_crt_valid(pem_path):
    """
    Check the validity of the SSL certificate.

    Load the SSL certificate from the specified path, parse and check its validity period.
    If the certificate is expired or invalid, raise a RuntimeError. The success
    message is logged only after the validity period has been checked.

    Parameters:
        pem_path (str): The file path of the SSL certificate.

    Raises:
        RuntimeError: If the SSL certificate is invalid or expired.
    """
    try:
        with FileOpen(pem_path, "r") as f:
            pem_data = f.read()
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
        # NOTE(review): `parser` is presumably dateutil.parser and is expected to
        # return timezone-aware datetimes here, since they are compared with an
        # aware UTC "now" below - confirm against the file's imports.
        pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
    except Exception as e:
        logger.error("Failed to parse the SSL certificate. Check the certificate.")
        raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e

    now_utc = datetime.now(tz=timezone.utc)
    if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
        raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")

    # Logged last so that an expired or not-yet-valid certificate is never
    # reported as having passed verification.
    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from msprobe.core.common.file_utils import load_yaml
3
18