mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (278) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,15 +12,20 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import csv
18
17
  import fcntl
19
18
  import os
19
+ import stat
20
20
  import json
21
21
  import re
22
22
  import shutil
23
+ from datetime import datetime, timezone
24
+ from dateutil import parser
25
+ import OpenSSL
23
26
  import yaml
24
27
  import numpy as np
28
+ import pandas as pd
25
29
 
26
30
  from msprobe.core.common.log import logger
27
31
  from msprobe.core.common.exceptions import FileCheckException
@@ -66,9 +70,11 @@ class FileChecker:
66
70
  self.check_path_ability()
67
71
  if self.is_script:
68
72
  check_path_owner_consistent(self.file_path)
69
- check_path_pattern_vaild(self.file_path)
73
+ check_path_pattern_valid(self.file_path)
70
74
  check_common_file_size(self.file_path)
71
75
  check_file_suffix(self.file_path, self.file_type)
76
+ if self.path_type == FileCheckConst.FILE:
77
+ check_dirpath_before_read(self.file_path)
72
78
  return self.file_path
73
79
 
74
80
  def check_path_ability(self):
@@ -121,9 +127,10 @@ class FileOpen:
121
127
  self.file_path = os.path.realpath(self.file_path)
122
128
  check_path_length(self.file_path)
123
129
  self.check_ability_and_owner()
124
- check_path_pattern_vaild(self.file_path)
130
+ check_path_pattern_valid(self.file_path)
125
131
  if os.path.exists(self.file_path):
126
132
  check_common_file_size(self.file_path)
133
+ check_dirpath_before_read(self.file_path)
127
134
 
128
135
  def check_ability_and_owner(self):
129
136
  if self.mode in self.SUPPORT_READ_MODE:
@@ -187,12 +194,12 @@ def check_other_user_writable(path):
187
194
 
188
195
  def check_path_owner_consistent(path):
189
196
  file_owner = os.stat(path).st_uid
190
- if file_owner != os.getuid():
197
+ if file_owner != os.getuid() and os.getuid() != 0:
191
198
  logger.error('The file path %s may be insecure because is does not belong to you.' % path)
192
199
  raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
193
200
 
194
201
 
195
- def check_path_pattern_vaild(path):
202
+ def check_path_pattern_valid(path):
196
203
  if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
197
204
  logger.error('The file path %s contains special characters.' % (path))
198
205
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
@@ -214,7 +221,8 @@ def check_common_file_size(file_path):
214
221
  for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
215
222
  if file_path.endswith(suffix):
216
223
  check_file_size(file_path, max_size)
217
- break
224
+ return
225
+ check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
218
226
 
219
227
 
220
228
  def check_file_suffix(file_path, file_suffix):
@@ -235,9 +243,18 @@ def check_path_type(file_path, file_type):
235
243
  raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
236
244
 
237
245
 
246
+ def check_others_writable(directory):
247
+ dir_stat = os.stat(directory)
248
+ is_writable = (
249
+ bool(dir_stat.st_mode & stat.S_IWGRP) or # 组可写
250
+ bool(dir_stat.st_mode & stat.S_IWOTH) # 其他用户可写
251
+ )
252
+ return is_writable
253
+
254
+
238
255
  def make_dir(dir_path):
239
- dir_path = os.path.realpath(dir_path)
240
256
  check_path_before_create(dir_path)
257
+ dir_path = os.path.realpath(dir_path)
241
258
  if os.path.isdir(dir_path):
242
259
  return
243
260
  try:
@@ -259,8 +276,9 @@ def create_directory(dir_path):
259
276
  Exception Description:
260
277
  when invalid data throw exception
261
278
  """
262
- dir_path = os.path.realpath(dir_path)
279
+ check_link(dir_path)
263
280
  check_path_before_create(dir_path)
281
+ dir_path = os.path.realpath(dir_path)
264
282
  parent_dir = os.path.dirname(dir_path)
265
283
  if not os.path.isdir(parent_dir):
266
284
  create_directory(parent_dir)
@@ -268,6 +286,7 @@ def create_directory(dir_path):
268
286
 
269
287
 
270
288
  def check_path_before_create(path):
289
+ check_link(path)
271
290
  if path_len_exceeds_limit(path):
272
291
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
273
292
 
@@ -276,6 +295,17 @@ def check_path_before_create(path):
276
295
  'The file path {} contains special characters.'.format(path))
277
296
 
278
297
 
298
+ def check_dirpath_before_read(path):
299
+ path = os.path.realpath(path)
300
+ dirpath = os.path.dirname(path)
301
+ if check_others_writable(dirpath):
302
+ logger.warning(f"The directory is writable by others: {dirpath}.")
303
+ try:
304
+ check_path_owner_consistent(dirpath)
305
+ except FileCheckException:
306
+ logger.warning(f"The directory {dirpath} is not yours.")
307
+
308
+
279
309
  def check_file_or_directory_path(path, isdir=False):
280
310
  """
281
311
  Function Description:
@@ -322,7 +352,7 @@ def check_file_type(path):
322
352
  elif os.path.isfile(path):
323
353
  return FileCheckConst.FILE
324
354
  else:
325
- logger.error('Neither a file nor a directory.')
355
+ logger.error(f'{path} does not exist, please check!')
326
356
  raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
327
357
 
328
358
 
@@ -338,10 +368,10 @@ def load_yaml(yaml_path):
338
368
  return yaml_data
339
369
 
340
370
 
341
- def load_npy(filepath, enable_pickle=False):
371
+ def load_npy(filepath):
342
372
  check_file_or_directory_path(filepath)
343
373
  try:
344
- npy = np.load(filepath, allow_pickle=enable_pickle)
374
+ npy = np.load(filepath, allow_pickle=False)
345
375
  except Exception as e:
346
376
  logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
347
377
  raise RuntimeError(f"Load numpy file {filepath} failed.") from e
@@ -351,7 +381,7 @@ def load_npy(filepath, enable_pickle=False):
351
381
  def load_json(json_path):
352
382
  try:
353
383
  with FileOpen(json_path, "r") as f:
354
- fcntl.flock(f, fcntl.LOCK_EX)
384
+ fcntl.flock(f, fcntl.LOCK_SH)
355
385
  data = json.load(f)
356
386
  fcntl.flock(f, fcntl.LOCK_UN)
357
387
  except Exception as e:
@@ -360,11 +390,11 @@ def load_json(json_path):
360
390
  return data
361
391
 
362
392
 
363
- def save_json(json_path, data, indent=None):
364
- json_path = os.path.realpath(json_path)
393
+ def save_json(json_path, data, indent=None, mode="w"):
365
394
  check_path_before_create(json_path)
395
+ json_path = os.path.realpath(json_path)
366
396
  try:
367
- with FileOpen(json_path, 'w') as f:
397
+ with FileOpen(json_path, mode) as f:
368
398
  fcntl.flock(f, fcntl.LOCK_EX)
369
399
  json.dump(data, f, indent=indent)
370
400
  fcntl.flock(f, fcntl.LOCK_UN)
@@ -374,6 +404,35 @@ def save_json(json_path, data, indent=None):
374
404
  change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
375
405
 
376
406
 
407
+ def save_yaml(yaml_path, data):
408
+ check_path_before_create(yaml_path)
409
+ yaml_path = os.path.realpath(yaml_path)
410
+ try:
411
+ with FileOpen(yaml_path, 'w') as f:
412
+ fcntl.flock(f, fcntl.LOCK_EX)
413
+ yaml.dump(data, f, sort_keys=False)
414
+ fcntl.flock(f, fcntl.LOCK_UN)
415
+ except Exception as e:
416
+ logger.error(f'Save yaml file "{os.path.basename(yaml_path)}" failed.')
417
+ raise RuntimeError(f"Save yaml file {yaml_path} failed.") from e
418
+ change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
419
+
420
+
421
+ def save_excel(path, data):
422
+ check_path_before_create(path)
423
+ path = os.path.realpath(path)
424
+ try:
425
+ if isinstance(data, pd.DataFrame):
426
+ data.to_excel(path, index=False)
427
+ else:
428
+ logger.error(f'unsupported data type.')
429
+ return
430
+ except Exception as e:
431
+ logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
432
+ raise RuntimeError(f"Save excel file {path} failed.") from e
433
+ change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
434
+
435
+
377
436
  def move_file(src_path, dst_path):
378
437
  check_file_or_directory_path(src_path)
379
438
  check_path_before_create(dst_path)
@@ -386,8 +445,8 @@ def move_file(src_path, dst_path):
386
445
 
387
446
 
388
447
  def save_npy(data, filepath):
389
- filepath = os.path.realpath(filepath)
390
448
  check_path_before_create(filepath)
449
+ filepath = os.path.realpath(filepath)
391
450
  try:
392
451
  np.save(filepath, data)
393
452
  except Exception as e:
@@ -396,9 +455,9 @@ def save_npy(data, filepath):
396
455
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
397
456
 
398
457
 
399
- def save_npy_to_txt(self, data, dst_file='', align=0):
458
+ def save_npy_to_txt(data, dst_file='', align=0):
400
459
  if os.path.exists(dst_file):
401
- self.log.info("Dst file %s exists, will not save new one.", dst_file)
460
+ logger.info("Dst file %s exists, will not save new one." % dst_file)
402
461
  return
403
462
  shape = data.shape
404
463
  data = data.flatten()
@@ -408,10 +467,11 @@ def save_npy_to_txt(self, data, dst_file='', align=0):
408
467
  pad_array = np.zeros((align - data.size % align,))
409
468
  data = np.append(data, pad_array)
410
469
  check_path_before_create(dst_file)
470
+ dst_file = os.path.realpath(dst_file)
411
471
  try:
412
472
  np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
413
473
  except Exception as e:
414
- self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file)
474
+ logger.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
415
475
  change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
416
476
 
417
477
 
@@ -421,8 +481,8 @@ def save_workbook(workbook, file_path):
421
481
  workbook: 要保存的工作簿对象
422
482
  file_path: 文件保存路径
423
483
  """
424
- file_path = os.path.realpath(file_path)
425
484
  check_path_before_create(file_path)
485
+ file_path = os.path.realpath(file_path)
426
486
  try:
427
487
  workbook.save(file_path)
428
488
  except Exception as e:
@@ -431,9 +491,27 @@ def save_workbook(workbook, file_path):
431
491
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
432
492
 
433
493
 
434
- def write_csv(data, filepath, mode="a+"):
435
- file_path = os.path.realpath(filepath)
494
+ def write_csv(data, filepath, mode="a+", malicious_check=False):
495
+ def csv_value_is_valid(value: str) -> bool:
496
+ if not isinstance(value, str):
497
+ return True
498
+ try:
499
+ # -1.00 or +1.00 should be consdiered as digit numbers
500
+ float(value)
501
+ except ValueError:
502
+ # otherwise, they will be considered as formular injections
503
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
504
+ return True
505
+
506
+ if malicious_check:
507
+ for row in data:
508
+ for cell in row:
509
+ if not csv_value_is_valid(cell):
510
+ raise RuntimeError(f"Malicious value [{cell}] is not allowed "
511
+ f"to be written into the csv: {filepath}.")
512
+
436
513
  check_path_before_create(filepath)
514
+ file_path = os.path.realpath(filepath)
437
515
  try:
438
516
  with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
439
517
  writer = csv.writer(f)
@@ -444,6 +522,54 @@ def write_csv(data, filepath, mode="a+"):
444
522
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
445
523
 
446
524
 
525
+ def read_csv(filepath, as_pd=True):
526
+ check_file_or_directory_path(filepath)
527
+ try:
528
+ if as_pd:
529
+ csv_data = pd.read_csv(filepath)
530
+ else:
531
+ with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
532
+ csv_reader = csv.reader(f, delimiter=',')
533
+ csv_data = list(csv_reader)
534
+ except Exception as e:
535
+ logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
536
+ raise RuntimeError(f"Read csv file {filepath} failed.") from e
537
+ return csv_data
538
+
539
+
540
+ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False):
541
+ def csv_value_is_valid(value: str) -> bool:
542
+ if not isinstance(value, str):
543
+ return True
544
+ try:
545
+ # -1.00 or +1.00 should be consdiered as digit numbers
546
+ float(value)
547
+ except ValueError:
548
+ # otherwise, they will be considered as formular injections
549
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
550
+ return True
551
+
552
+ if not isinstance(data, pd.DataFrame):
553
+ raise ValueError("The data type of data is not supported. Only support pd.DataFrame.")
554
+
555
+ if malicious_check:
556
+ for i in range(len(data)):
557
+ for j in range(len(data.columns)):
558
+ cell = data.iloc[i, j]
559
+ if not csv_value_is_valid(cell):
560
+ raise RuntimeError(f"Malicious value [{cell}] is not allowed "
561
+ f"to be written into the csv: {filepath}.")
562
+
563
+ check_path_before_create(filepath)
564
+ file_path = os.path.realpath(filepath)
565
+ try:
566
+ data.to_csv(filepath, mode=mode, header=header, index=False)
567
+ except Exception as e:
568
+ logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
569
+ raise RuntimeError(f"Save csv file {file_path} failed.") from e
570
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
571
+
572
+
447
573
  def remove_path(path):
448
574
  if not os.path.exists(path):
449
575
  return
@@ -476,3 +602,46 @@ def get_json_contents(file_path):
476
602
  def get_file_content_bytes(file):
477
603
  with FileOpen(file, 'rb') as file_handle:
478
604
  return file_handle.read()
605
+
606
+
607
+ # 对os.walk设置遍历深度
608
+ def os_walk_for_files(path, depth):
609
+ res = []
610
+ for root, _, files in os.walk(path, topdown=True):
611
+ check_file_or_directory_path(root, isdir=True)
612
+ if root.count(os.sep) - path.count(os.sep) >= depth:
613
+ _[:] = []
614
+ else:
615
+ for file in files:
616
+ res.append({"file": file, "root": root})
617
+ return res
618
+
619
+
620
+ def check_crt_valid(pem_path):
621
+ """
622
+ Check the validity of the SSL certificate.
623
+
624
+ Load the SSL certificate from the specified path, parse and check its validity period.
625
+ If the certificate is expired or invalid, raise a RuntimeError.
626
+
627
+ Parameters:
628
+ pem_path (str): The file path of the SSL certificate.
629
+
630
+ Raises:
631
+ RuntimeError: If the SSL certificate is invalid or expired.
632
+ """
633
+ try:
634
+ with FileOpen(pem_path, "r") as f:
635
+ pem_data = f.read()
636
+ cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
637
+ pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
638
+ pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
639
+ logger.info(f"The SSL certificate passes the verification and the validity period "
640
+ f"starts from {pem_start} ends at {pem_end}.")
641
+ except Exception as e:
642
+ logger.error("Failed to parse the SSL certificate. Check the certificate.")
643
+ raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e
644
+
645
+ now_utc = datetime.now(tz=timezone.utc)
646
+ if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
647
+ raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
@@ -0,0 +1,53 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from msprobe.core.common.file_utils import load_yaml
18
+
19
+
20
class InplaceOpChecker:
    """Decide whether an API name is an in-place operation.

    The op tables are read from ``inplace_ops.yaml`` (located next to this module)
    and cached on the class in ``INPLACE_OPS_DICT``, keyed by category.
    """
    OP_FUNCTIONAL = 'functional'
    OP_TENSOR = 'tensor'
    OP_TORCH = 'torch'
    OP_DISTRIBUTED = 'distributed'

    # Category name -> list of in-place op names; stays None until load_ops() succeeds.
    INPLACE_OPS_DICT = None

    @classmethod
    def load_ops(cls):
        """Load the in-place op tables from ``inplace_ops.yaml`` and cache them.

        Does nothing if the tables are already loaded. A section missing from the
        YAML, or an empty YAML document, yields an empty list for that category
        rather than ``None``.
        """
        if cls.INPLACE_OPS_DICT is not None:
            return
        cur_path = os.path.dirname(os.path.realpath(__file__))
        yaml_path = os.path.join(cur_path, "inplace_ops.yaml")
        # Guard against load_yaml returning None for an empty document.
        all_ops = load_yaml(yaml_path) or {}
        cls.INPLACE_OPS_DICT = {
            cls.OP_FUNCTIONAL: all_ops.get('inplace_functional_op') or [],
            cls.OP_TENSOR: all_ops.get('inplace_tensor_op') or [],
            cls.OP_TORCH: all_ops.get('inplace_torch_op') or [],
            cls.OP_DISTRIBUTED: all_ops.get('inplace_distributed_op') or [],
        }

    @classmethod
    def check(cls, api, category='distributed'):
        """Given an API name and a category, check whether it is an in-place op.

        Args:
            api: API name to look up, e.g. ``"add_"`` or ``"all_reduce"``.
            category: One of ``OP_FUNCTIONAL``, ``OP_TENSOR``, ``OP_TORCH`` or
                ``OP_DISTRIBUTED``; defaults to ``'distributed'``.

        Returns:
            True if ``api`` is listed for ``category``. Returns False for unknown
            categories and for categories that have no ops configured.
        """
        if cls.INPLACE_OPS_DICT is None:
            cls.load_ops()
        # A missing/None op list must mean "not in-place", not a TypeError.
        return api in (cls.INPLACE_OPS_DICT.get(category) or ())
51
+
52
+
53
+ InplaceOpChecker.load_ops()
@@ -0,0 +1,251 @@
1
+ inplace_functional_op:
2
+ - threshold_
3
+ - relu_
4
+ - hardtanh_
5
+ - elu_
6
+ - selu_
7
+ - celu_
8
+ - leaky_relu_
9
+ - rrelu_
10
+
11
+ inplace_tensor_op:
12
+ - __iadd__
13
+ - __iand__
14
+ - __idiv__
15
+ - __ifloordiv__
16
+ - __ilshift__
17
+ - __imod__
18
+ - __imul__
19
+ - __ior__
20
+ - __irshift__
21
+ - __isub__
22
+ - __ixor__
23
+ - abs_
24
+ - absolute_
25
+ - acos_
26
+ - acosh_
27
+ - add_
28
+ - addbmm_
29
+ - addcdiv_
30
+ - addcmul_
31
+ - addmm_
32
+ - addmv_
33
+ - addr_
34
+ - arccos_
35
+ - arccosh_
36
+ - arcsin_
37
+ - arcsinh_
38
+ - arctan_
39
+ - arctanh_
40
+ - asin_
41
+ - asinh_
42
+ - atan2_
43
+ - atan_
44
+ - atanh_
45
+ - baddbmm_
46
+ - bernoulli_
47
+ - bitwise_and_
48
+ - bitwise_not_
49
+ - bitwise_or_
50
+ - bitwise_xor_
51
+ - cauchy_
52
+ - ceil_
53
+ - clamp_
54
+ - clamp_max_
55
+ - clamp_min_
56
+ - clip_
57
+ - copysign_
58
+ - cos_
59
+ - cosh_
60
+ - cumprod_
61
+ - cumsum_
62
+ - deg2rad_
63
+ - digamma_
64
+ - div_
65
+ - divide_
66
+ - eq_
67
+ - erf_
68
+ - erfc_
69
+ - erfinv_
70
+ - exp2_
71
+ - exp_
72
+ - expm1_
73
+ - exponential_
74
+ - fill_
75
+ - fill_diagonal_
76
+ - fix_
77
+ - float_power_
78
+ - floor_
79
+ - floor_divide_
80
+ - fmod_
81
+ - frac_
82
+ - gcd_
83
+ - ge_
84
+ - geometric_
85
+ - greater_
86
+ - gt_
87
+ - greater_equal_
88
+ - heaviside_
89
+ - hypot_
90
+ - igamma_
91
+ - igammac_
92
+ - index_add_
93
+ - index_copy_
94
+ - index_fill_
95
+ - index_put_
96
+ - lcm_
97
+ - ldexp_
98
+ - le_
99
+ - lerp_
100
+ - less_
101
+ - less_equal_
102
+ - lgamma_
103
+ - log10_
104
+ - log1p_
105
+ - log2_
106
+ - log_
107
+ - log_normal_
108
+ - logical_and_
109
+ - logical_not_
110
+ - logical_or_
111
+ - logical_xor_
112
+ - logit_
113
+ - lt_
114
+ - map2_
115
+ - map_
116
+ - masked_fill_
117
+ - masked_scatter_
118
+ - mul_
119
+ - multiply_
120
+ - mvlgamma_
121
+ - ne_
122
+ - neg_
123
+ - negative_
124
+ - normal_
125
+ - not_equal_
126
+ - pow_
127
+ - polygamma_
128
+ - put_
129
+ - rad2deg_
130
+ - reciprocal_
131
+ - relu_
132
+ - remainder_
133
+ - renorm_
134
+ - resize_
135
+ - resize_as_
136
+ - round_
137
+ - rsqrt_
138
+ - scatter_
139
+ - scatter_add_
140
+ - sgn_
141
+ - sigmoid_
142
+ - sign_
143
+ - sin_
144
+ - sinc_
145
+ - sinh_
146
+ - sqrt_
147
+ - square_
148
+ - squeeze_
149
+ - sub_
150
+ - t_
151
+ - tan_
152
+ - tanh_
153
+ - transpose_
154
+ - tril_
155
+ - triu_
156
+ - true_divide_
157
+ - trunc_
158
+ - unsqueeze_
159
+ - xlogy_
160
+
161
+ inplace_torch_op:
162
+ - _add_relu_
163
+ - abs_
164
+ - acos_
165
+ - acosh_
166
+ - addmv_
167
+ - alpha_dropout_
168
+ - arccos_
169
+ - arccosh_
170
+ - arcsin_
171
+ - arcsinh_
172
+ - arctan_
173
+ - arctanh_
174
+ - asin_
175
+ - asinh_
176
+ - atan_
177
+ - atanh_
178
+ - ceil_
179
+ - celu_
180
+ - clamp_
181
+ - clamp_max_
182
+ - clamp_min_
183
+ - clip_
184
+ - cos_
185
+ - cosh_
186
+ - deg2rad_
187
+ - dropout_
188
+ - embedding_renorm_
189
+ - erf_
190
+ - erfc_
191
+ - exp2_
192
+ - exp_
193
+ - expm1_
194
+ - feature_alpha_dropout_
195
+ - feature_dropout_
196
+ - fill_
197
+ - fix_
198
+ - floor_
199
+ - frac_
200
+ - gcd_
201
+ - index_put_
202
+ - lcm_
203
+ - ldexp_
204
+ - log10_
205
+ - log1p_
206
+ - log2_
207
+ - log_
208
+ - logit_
209
+ - nan_to_num_
210
+ - neg_
211
+ - negative_
212
+ - rad2deg_
213
+ - reciprocal_
214
+ - relu_
215
+ - resize_as_
216
+ - round_
217
+ - rrelu_
218
+ - rsqrt_
219
+ - selu_
220
+ - sigmoid_
221
+ - sin_
222
+ - sinc_
223
+ - sinh_
224
+ - sqrt_
225
+ - square_
226
+ - tan_
227
+ - tanh_
228
+ - threshold_
229
+ - trunc_
230
+ - xlogy_
231
+
232
+ inplace_distributed_op:
233
+ - broadcast
234
+ - all_reduce
235
+ - reduce
236
+ - all_gather
237
+ - gather
238
+ - scatter
239
+ - reduce_scatter
240
+ - _reduce_scatter_base
241
+ - _all_gather_base
242
+ - send
243
+ - recv
244
+ - irecv
245
+ - isend
246
+ - all_to_all_single
247
+ - all_to_all
248
+ - all_gather_into_tensor
249
+ - reduce_scatter_tensor
250
+
251
+