mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299):
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,13 +12,16 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import csv
18
17
  import fcntl
19
18
  import os
19
+ import stat
20
20
  import json
21
21
  import re
22
22
  import shutil
23
+ from datetime import datetime, timezone
24
+ from dateutil import parser
23
25
  import yaml
24
26
  import numpy as np
25
27
  import pandas as pd
@@ -67,9 +69,11 @@ class FileChecker:
67
69
  self.check_path_ability()
68
70
  if self.is_script:
69
71
  check_path_owner_consistent(self.file_path)
70
- check_path_pattern_vaild(self.file_path)
72
+ check_path_pattern_valid(self.file_path)
71
73
  check_common_file_size(self.file_path)
72
74
  check_file_suffix(self.file_path, self.file_type)
75
+ if self.path_type == FileCheckConst.FILE:
76
+ check_dirpath_before_read(self.file_path)
73
77
  return self.file_path
74
78
 
75
79
  def check_path_ability(self):
@@ -122,9 +126,10 @@ class FileOpen:
122
126
  self.file_path = os.path.realpath(self.file_path)
123
127
  check_path_length(self.file_path)
124
128
  self.check_ability_and_owner()
125
- check_path_pattern_vaild(self.file_path)
129
+ check_path_pattern_valid(self.file_path)
126
130
  if os.path.exists(self.file_path):
127
131
  check_common_file_size(self.file_path)
132
+ check_dirpath_before_read(self.file_path)
128
133
 
129
134
  def check_ability_and_owner(self):
130
135
  if self.mode in self.SUPPORT_READ_MODE:
@@ -193,7 +198,7 @@ def check_path_owner_consistent(path):
193
198
  raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
194
199
 
195
200
 
196
- def check_path_pattern_vaild(path):
201
+ def check_path_pattern_valid(path):
197
202
  if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
198
203
  logger.error('The file path %s contains special characters.' % (path))
199
204
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
@@ -217,7 +222,6 @@ def check_common_file_size(file_path):
217
222
  check_file_size(file_path, max_size)
218
223
  return
219
224
  check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
220
-
221
225
 
222
226
 
223
227
  def check_file_suffix(file_path, file_suffix):
@@ -238,9 +242,18 @@ def check_path_type(file_path, file_type):
238
242
  raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
239
243
 
240
244
 
245
def check_others_writable(directory):
    """
    Report whether a directory can be written by anyone other than its owner.

    Args:
        directory: Path to inspect; it is passed to ``os.stat``, which follows symlinks.

    Returns:
        bool: True when the group-write bit or the other-write bit is set in the
        path's mode, otherwise False.
    """
    mode = os.stat(directory).st_mode
    # Either bit means a non-owner may modify the directory contents.
    group_or_other_write = stat.S_IWGRP | stat.S_IWOTH
    return (mode & group_or_other_write) != 0
252
+
253
+
241
254
  def make_dir(dir_path):
242
- dir_path = os.path.realpath(dir_path)
243
255
  check_path_before_create(dir_path)
256
+ dir_path = os.path.realpath(dir_path)
244
257
  if os.path.isdir(dir_path):
245
258
  return
246
259
  try:
@@ -262,8 +275,9 @@ def create_directory(dir_path):
262
275
  Exception Description:
263
276
  when invalid data throw exception
264
277
  """
265
- dir_path = os.path.realpath(dir_path)
278
+ check_link(dir_path)
266
279
  check_path_before_create(dir_path)
280
+ dir_path = os.path.realpath(dir_path)
267
281
  parent_dir = os.path.dirname(dir_path)
268
282
  if not os.path.isdir(parent_dir):
269
283
  create_directory(parent_dir)
@@ -271,6 +285,7 @@ def create_directory(dir_path):
271
285
 
272
286
 
273
287
  def check_path_before_create(path):
288
+ check_link(path)
274
289
  if path_len_exceeds_limit(path):
275
290
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.')
276
291
 
@@ -279,6 +294,17 @@ def check_path_before_create(path):
279
294
  'The file path {} contains special characters.'.format(path))
280
295
 
281
296
 
297
def check_dirpath_before_read(path):
    """
    Warn, without failing, when the directory containing ``path`` looks unsafe.

    ``path`` is resolved with ``os.path.realpath`` first, so the check applies to
    the directory that really holds the target. Two situations are only logged
    as warnings:
      * the directory has its group-write or other-write permission bit set
        (as reported by ``check_others_writable``);
      * ``check_path_owner_consistent`` rejects the directory by raising
        ``FileCheckException``. That exception is caught and turned into a warning.

    Args:
        path: File path about to be read.
    """
    parent_dir = os.path.dirname(os.path.realpath(path))
    if check_others_writable(parent_dir):
        logger.warning(f"The directory is writable by others: {parent_dir}.")
    try:
        check_path_owner_consistent(parent_dir)
    except FileCheckException:
        logger.warning(f"The directory {parent_dir} is not yours.")
306
+
307
+
282
308
  def check_file_or_directory_path(path, isdir=False):
283
309
  """
284
310
  Function Description:
@@ -344,7 +370,7 @@ def load_yaml(yaml_path):
344
370
  def load_npy(filepath):
345
371
  check_file_or_directory_path(filepath)
346
372
  try:
347
- npy = np.load(filepath)
373
+ npy = np.load(filepath, allow_pickle=False)
348
374
  except Exception as e:
349
375
  logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
350
376
  raise RuntimeError(f"Load numpy file {filepath} failed.") from e
@@ -354,7 +380,7 @@ def load_npy(filepath):
354
380
  def load_json(json_path):
355
381
  try:
356
382
  with FileOpen(json_path, "r") as f:
357
- fcntl.flock(f, fcntl.LOCK_EX)
383
+ fcntl.flock(f, fcntl.LOCK_SH)
358
384
  data = json.load(f)
359
385
  fcntl.flock(f, fcntl.LOCK_UN)
360
386
  except Exception as e:
@@ -363,11 +389,11 @@ def load_json(json_path):
363
389
  return data
364
390
 
365
391
 
366
- def save_json(json_path, data, indent=None):
367
- json_path = os.path.realpath(json_path)
392
+ def save_json(json_path, data, indent=None, mode="w"):
368
393
  check_path_before_create(json_path)
394
+ json_path = os.path.realpath(json_path)
369
395
  try:
370
- with FileOpen(json_path, 'w') as f:
396
+ with FileOpen(json_path, mode) as f:
371
397
  fcntl.flock(f, fcntl.LOCK_EX)
372
398
  json.dump(data, f, indent=indent)
373
399
  fcntl.flock(f, fcntl.LOCK_UN)
@@ -378,8 +404,8 @@ def save_json(json_path, data, indent=None):
378
404
 
379
405
 
380
406
  def save_yaml(yaml_path, data):
381
- yaml_path = os.path.realpath(yaml_path)
382
407
  check_path_before_create(yaml_path)
408
+ yaml_path = os.path.realpath(yaml_path)
383
409
  try:
384
410
  with FileOpen(yaml_path, 'w') as f:
385
411
  fcntl.flock(f, fcntl.LOCK_EX)
@@ -391,6 +417,37 @@ def save_yaml(yaml_path, data):
391
417
  change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
392
418
 
393
419
 
420
def save_excel(path, data):
    """
    Write one DataFrame, or several DataFrames as named sheets, to an Excel file.

    Args:
        path: Destination file path. It is checked with ``check_path_before_create``
            and then resolved with ``os.path.realpath``.
        data: Either a ``pd.DataFrame``, written with a single ``to_excel`` call, or a
            list of ``(DataFrame, sheet_name)`` tuples, written one sheet per tuple in
            list order through a shared ``pd.ExcelWriter``.

    Raises:
        ValueError: If ``data`` has neither of the accepted shapes.
        RuntimeError: If pandas fails while writing the workbook.
    """
    def _classify(payload):
        # Map the payload to "single" or "list"; any other shape is rejected.
        if isinstance(payload, pd.DataFrame):
            return "single"
        if isinstance(payload, list):
            well_formed = all(
                isinstance(entry, tuple) and len(entry) == 2 and isinstance(entry[0], pd.DataFrame)
                for entry in payload
            )
            if well_formed:
                return "list"
        raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")

    check_path_before_create(path)
    path = os.path.realpath(path)

    layout = _classify(data)

    try:
        if layout == "single":
            data.to_excel(path, index=False)
        elif layout == "list":
            with pd.ExcelWriter(path) as writer:
                for frame, sheet in data:
                    frame.to_excel(writer, sheet_name=sheet, index=False)
    except Exception as e:
        logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
        raise RuntimeError(f"Save excel file {path} failed.") from e
    change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
447
+
448
+
449
+
450
+
394
451
  def move_file(src_path, dst_path):
395
452
  check_file_or_directory_path(src_path)
396
453
  check_path_before_create(dst_path)
@@ -403,8 +460,8 @@ def move_file(src_path, dst_path):
403
460
 
404
461
 
405
462
  def save_npy(data, filepath):
406
- filepath = os.path.realpath(filepath)
407
463
  check_path_before_create(filepath)
464
+ filepath = os.path.realpath(filepath)
408
465
  try:
409
466
  np.save(filepath, data)
410
467
  except Exception as e:
@@ -425,6 +482,7 @@ def save_npy_to_txt(data, dst_file='', align=0):
425
482
  pad_array = np.zeros((align - data.size % align,))
426
483
  data = np.append(data, pad_array)
427
484
  check_path_before_create(dst_file)
485
+ dst_file = os.path.realpath(dst_file)
428
486
  try:
429
487
  np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
430
488
  except Exception as e:
@@ -438,8 +496,8 @@ def save_workbook(workbook, file_path):
438
496
  workbook: 要保存的工作簿对象
439
497
  file_path: 文件保存路径
440
498
  """
441
- file_path = os.path.realpath(file_path)
442
499
  check_path_before_create(file_path)
500
+ file_path = os.path.realpath(file_path)
443
501
  try:
444
502
  workbook.save(file_path)
445
503
  except Exception as e:
@@ -451,7 +509,7 @@ def save_workbook(workbook, file_path):
451
509
  def write_csv(data, filepath, mode="a+", malicious_check=False):
452
510
  def csv_value_is_valid(value: str) -> bool:
453
511
  if not isinstance(value, str):
454
- return True
512
+ return True
455
513
  try:
456
514
  # -1.00 or +1.00 should be consdiered as digit numbers
457
515
  float(value)
@@ -459,16 +517,16 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
459
517
  # otherwise, they will be considered as formular injections
460
518
  return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
461
519
  return True
462
-
520
+
463
521
  if malicious_check:
464
522
  for row in data:
465
523
  for cell in row:
466
524
  if not csv_value_is_valid(cell):
467
- raise RuntimeError(f"Malicious value [{cell}] is not allowed " \
525
+ raise RuntimeError(f"Malicious value [{cell}] is not allowed "
468
526
  f"to be written into the csv: {filepath}.")
469
527
 
470
- file_path = os.path.realpath(filepath)
471
528
  check_path_before_create(filepath)
529
+ file_path = os.path.realpath(filepath)
472
530
  try:
473
531
  with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
474
532
  writer = csv.writer(f)
@@ -479,16 +537,54 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
479
537
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
480
538
 
481
539
 
482
- def read_csv(filepath):
540
+ def read_csv(filepath, as_pd=True, header='infer'):
483
541
  check_file_or_directory_path(filepath)
484
542
  try:
485
- csv_data = pd.read_csv(filepath)
543
+ if as_pd:
544
+ csv_data = pd.read_csv(filepath, header=header)
545
+ else:
546
+ with FileOpen(filepath, 'r', encoding='utf-8-sig') as f:
547
+ csv_reader = csv.reader(f, delimiter=',')
548
+ csv_data = list(csv_reader)
486
549
  except Exception as e:
487
550
  logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
488
551
  raise RuntimeError(f"Read csv file {filepath} failed.") from e
489
552
  return csv_data
490
553
 
491
554
 
555
def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False):
    """
    Write a DataFrame to a CSV file, optionally rejecting formula-injection values.

    Args:
        data: The ``pd.DataFrame`` to write. Its index is not written.
        filepath: Destination path. It is checked with ``check_path_before_create``
            and then resolved with ``os.path.realpath``; the resolved path is the one
            written and chmod-ed, matching the other ``save_*`` helpers.
        mode: File mode forwarded to ``DataFrame.to_csv`` (e.g. "w" or "a").
        header: Forwarded to ``DataFrame.to_csv``; whether to write column names.
        malicious_check: When True, every string cell that does not parse as a float
            is searched with ``FileCheckConst.CSV_BLACK_LIST`` before anything is written.

    Raises:
        ValueError: If ``data`` is not a ``pd.DataFrame``.
        RuntimeError: If a malicious cell is found or the write itself fails.
    """
    if not isinstance(data, pd.DataFrame):
        raise ValueError("The data type of data is not supported. Only support pd.DataFrame.")

    if malicious_check:
        # Compile the black-list pattern once instead of once per cell.
        black_list = re.compile(FileCheckConst.CSV_BLACK_LIST)

        def csv_value_is_valid(value) -> bool:
            if not isinstance(value, str):
                return True
            try:
                # -1.00 or +1.00 should be considered as digit numbers
                float(value)
            except ValueError:
                # otherwise, they will be considered as formula injections
                return not bool(black_list.search(value))
            return True

        # itertuples visits cells in the same row-major order as the former
        # iloc[i, j] double loop, without the per-cell indexing overhead.
        for row in data.itertuples(index=False, name=None):
            for cell in row:
                if not csv_value_is_valid(cell):
                    raise RuntimeError(f"Malicious value [{cell}] is not allowed "
                                       f"to be written into the csv: {filepath}.")

    check_path_before_create(filepath)
    file_path = os.path.realpath(filepath)
    try:
        data.to_csv(file_path, mode=mode, header=header, index=False)
    except Exception as e:
        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
        raise RuntimeError(f"Save csv file {file_path} failed.") from e
    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
586
+
587
+
492
588
  def remove_path(path):
493
589
  if not os.path.exists(path):
494
590
  return
@@ -521,3 +617,57 @@ def get_json_contents(file_path):
521
617
def get_file_content_bytes(file):
    """Open ``file`` through ``FileOpen`` in binary mode and return its full contents as bytes."""
    with FileOpen(file, 'rb') as handle:
        content = handle.read()
    return content
620
+
621
+
622
+ # Walk a directory tree with os.walk, limited to a maximum traversal depth.
623
def os_walk_for_files(path, depth):
    """
    Collect files under ``path`` while descending at most ``depth`` directory levels.

    Files directly inside ``path`` are at level 0, files in its immediate
    subdirectories at level 1, and so on. Only files at a level ``< depth`` are
    collected, and directories at a level ``>= depth`` are not descended into.
    Every visited directory is validated with ``check_file_or_directory_path``.

    Args:
        path: Root directory to walk.
        depth: Number of directory levels whose files are collected.

    Returns:
        A list of dicts ``{"file": <file name>, "root": <directory as yielded by os.walk>}``.
    """
    res = []
    for root, dirs, files in os.walk(path, topdown=True):
        check_file_or_directory_path(root, isdir=True)
        # Measure the level relative to ``path`` instead of comparing raw separator
        # counts, so a trailing separator on ``path`` (e.g. "data/") does not make
        # every subdirectory look one level shallower than it is.
        rel_root = os.path.relpath(root, path)
        level = 0 if rel_root == os.curdir else rel_root.count(os.sep) + 1
        if level >= depth:
            # Emptying ``dirs`` in place stops a top-down os.walk from descending further.
            dirs[:] = []
        else:
            res.extend({"file": name, "root": root} for name in files)
    return res
633
+
634
+
635
def check_crt_valid(pem_path):
    """
    Check the validity of the SSL certificate.

    Load the PEM certificate from ``pem_path``, parse its notBefore/notAfter
    timestamps, and verify that the current UTC time lies inside that window.
    The success message is logged only after the window check has passed.

    Parameters:
        pem_path (str): The file path of the SSL certificate.

    Raises:
        RuntimeError: If the certificate cannot be read or parsed, is not yet
            valid, or has expired.
    """
    import OpenSSL
    try:
        with FileOpen(pem_path, "r") as f:
            pem_data = f.read()
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
        # get_notBefore()/get_notAfter() return bytes; decode before parsing the timestamps.
        pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
        pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
    except Exception as e:
        logger.error("Failed to parse the SSL certificate. Check the certificate.")
        raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e

    now_utc = datetime.now(tz=timezone.utc)
    # A certificate whose validity window has not started yet is not "expired";
    # report it separately so the operator gets an accurate message.
    if now_utc < pem_start:
        raise RuntimeError(f"The SSL certificate is not yet valid, {pem_path}")
    if cert.has_expired() or now_utc > pem_end:
        raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")

    logger.info(f"The SSL certificate passes the verification and the validity period "
                f"starts from {pem_start} ends at {pem_end}.")
664
+
665
+
666
def read_xlsx(file_path):
    """
    Read an Excel file into a pandas DataFrame.

    ``keep_default_na=False`` stops pandas from converting its default NA
    strings into NaN values.

    Args:
        file_path: Path of the workbook; validated with
            ``check_file_or_directory_path`` before it is read.

    Returns:
        The ``pd.DataFrame`` produced by ``pd.read_excel``.

    Raises:
        RuntimeError: If loading the workbook fails for any reason (original error chained).
    """
    check_file_or_directory_path(file_path)
    try:
        return pd.read_excel(file_path, keep_default_na=False)
    except Exception as err:
        logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
        raise RuntimeError(f"Read xlsx file {file_path} failed.") from err
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from msprobe.core.common.file_utils import load_yaml
3
18
 
@@ -157,6 +157,9 @@ inplace_tensor_op:
157
157
  - trunc_
158
158
  - unsqueeze_
159
159
  - xlogy_
160
+ - bitwise_left_shift_
161
+ - bitwise_right_shift_
162
+ - arctan2_
160
163
 
161
164
  inplace_torch_op:
162
165
  - _add_relu_
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  import time
3
18
  import sys
@@ -5,6 +20,16 @@ from functools import wraps
5
20
  from msprobe.core.common.const import MsgConst
6
21
 
7
22
 
23
def filter_special_chars(func):
    """
    Decorate a logger method so its message is sanitized before it is passed on.

    Each entry of ``MsgConst.SPECIAL_CHAR`` found in the message is replaced
    with ``'_'``. The wrapped method must accept ``(self, msg, **kwargs)``.
    A non-string message (e.g. an exception object) is first converted with
    ``str()`` so it is sanitized and logged instead of failing on ``.replace``.
    """
    @wraps(func)
    def func_level(self, msg, **kwargs):
        if not isinstance(msg, str):
            msg = str(msg)
        for char in MsgConst.SPECIAL_CHAR:
            msg = msg.replace(char, '_')
        return func(self, msg, **kwargs)

    return func_level
31
+
32
+
8
33
  class BaseLogger:
9
34
  def __init__(self):
10
35
  self.rank = None
@@ -21,14 +46,6 @@ class BaseLogger:
21
46
  def get_rank(self):
22
47
  return self.rank
23
48
 
24
- def filter_special_chars(func):
25
- @wraps(func)
26
- def func_level(self, msg, **kwargs):
27
- for char in MsgConst.SPECIAL_CHAR:
28
- msg = msg.replace(char, '_')
29
- return func(self, msg, **kwargs)
30
- return func_level
31
-
32
49
  @filter_special_chars
33
50
  def error(self, msg):
34
51
  if self.level <= MsgConst.LogLevel.ERROR.value:
@@ -56,6 +73,7 @@ class BaseLogger:
56
73
  return func(*args, **kwargs)
57
74
  else:
58
75
  return None
76
+
59
77
  return func_rank_0
60
78
 
61
79
  def info_on_rank_0(self, msg):
@@ -66,7 +84,7 @@ class BaseLogger:
66
84
 
67
85
  def warning_on_rank_0(self, msg):
68
86
  return self.on_rank_0(self.warning)(msg)
69
-
87
+
70
88
  def error_log_with_exp(self, msg, exception):
71
89
  self.error(msg)
72
90
  raise exception