mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -12,16 +12,20 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
-
15
+ import atexit
16
16
  import csv
17
17
  import fcntl
18
+ import io
18
19
  import os
20
+ import pickle
21
+ from multiprocessing import shared_memory
19
22
  import stat
20
23
  import json
21
24
  import re
22
25
  import shutil
23
- from datetime import datetime, timezone
24
- from dateutil import parser
26
+ import sys
27
+ import zipfile
28
+ import multiprocessing
25
29
  import yaml
26
30
  import numpy as np
27
31
  import pandas as pd
@@ -29,7 +33,10 @@ import pandas as pd
29
33
  from msprobe.core.common.decorator import recursion_depth_decorator
30
34
  from msprobe.core.common.log import logger
31
35
  from msprobe.core.common.exceptions import FileCheckException
32
- from msprobe.core.common.const import FileCheckConst
36
+ from msprobe.core.common.const import FileCheckConst, CompareConst
37
+ from msprobe.core.common.global_lock import global_lock, is_main_process
38
+
39
+ proc_lock = multiprocessing.Lock()
33
40
 
34
41
 
35
42
  class FileChecker:
@@ -165,6 +172,12 @@ def check_path_exists(path):
165
172
  if not os.path.exists(path):
166
173
  logger.error('The file path %s does not exist.' % path)
167
174
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
175
+
176
+
177
+ def check_path_not_exists(path):
178
+ if os.path.exists(path):
179
+ logger.error('The file path %s already exist.' % path)
180
+ raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
168
181
 
169
182
 
170
183
  def check_path_readability(path):
@@ -299,12 +312,13 @@ def check_path_before_create(path):
299
312
  def check_dirpath_before_read(path):
300
313
  path = os.path.realpath(path)
301
314
  dirpath = os.path.dirname(path)
302
- if check_others_writable(dirpath):
303
- logger.warning(f"The directory is writable by others: {dirpath}.")
304
- try:
305
- check_path_owner_consistent(dirpath)
306
- except FileCheckException:
307
- logger.warning(f"The directory {dirpath} is not yours.")
315
+ if dedup_log('check_dirpath_before_read', dirpath):
316
+ if check_others_writable(dirpath):
317
+ logger.warning(f"The directory is writable by others: {dirpath}.")
318
+ try:
319
+ check_path_owner_consistent(dirpath)
320
+ except FileCheckException:
321
+ logger.warning(f"The directory {dirpath} is not yours.")
308
322
 
309
323
 
310
324
  def check_file_or_directory_path(path, isdir=False):
@@ -446,6 +460,17 @@ def save_excel(path, data):
446
460
  return "list"
447
461
  raise ValueError("Data must be a DataFrame or a list of (DataFrame, sheet_name) pairs.")
448
462
 
463
+ def save_in_slice(df, base_name):
464
+ df_length = len(df)
465
+ if df_length < CompareConst.MAX_EXCEL_LENGTH:
466
+ df.to_excel(writer, sheet_name=base_name if base_name else 'Sheet1', index=False)
467
+ else:
468
+ slice_num = (df_length + CompareConst.MAX_EXCEL_LENGTH - 1) // CompareConst.MAX_EXCEL_LENGTH
469
+ slice_size = (df_length + slice_num - 1) // slice_num
470
+ for i in range(slice_num):
471
+ df.iloc[i * slice_size: min((i + 1) * slice_size, df_length)] \
472
+ .to_excel(writer, sheet_name=f'{base_name}_part_{i}' if base_name else f'part_{i}', index=False)
473
+
449
474
  check_path_before_create(path)
450
475
  path = os.path.realpath(path)
451
476
 
@@ -453,18 +478,27 @@ def save_excel(path, data):
453
478
  data_type = validate_data(data)
454
479
 
455
480
  try:
456
- if data_type == "single":
457
- data.to_excel(path, index=False)
458
- elif data_type == "list":
459
- with pd.ExcelWriter(path) as writer:
481
+ with pd.ExcelWriter(path) as writer:
482
+ if data_type == "single":
483
+ save_in_slice(data, None)
484
+ elif data_type == "list":
460
485
  for data_df, sheet_name in data:
461
- data_df.to_excel(writer, sheet_name=sheet_name, index=False)
486
+ save_in_slice(data_df, sheet_name)
462
487
  except Exception as e:
463
488
  logger.error(f'Save excel file "{os.path.basename(path)}" failed.')
464
489
  raise RuntimeError(f"Save excel file {path} failed.") from e
465
490
  change_mode(path, FileCheckConst.DATA_FILE_AUTHORITY)
466
491
 
467
492
 
493
+ def move_directory(src_path, dst_path):
494
+ check_file_or_directory_path(src_path, isdir=True)
495
+ check_path_before_create(dst_path)
496
+ try:
497
+ shutil.move(src_path, dst_path)
498
+ except Exception as e:
499
+ logger.error(f"move directory {src_path} to {dst_path} failed")
500
+ raise RuntimeError(f"move directory {src_path} to {dst_path} failed") from e
501
+ change_mode(dst_path, FileCheckConst.DATA_DIR_AUTHORITY)
468
502
 
469
503
 
470
504
  def move_file(src_path, dst_path):
@@ -530,7 +564,7 @@ def write_csv(data, filepath, mode="a+", malicious_check=False):
530
564
  if not isinstance(value, str):
531
565
  return True
532
566
  try:
533
- # -1.00 or +1.00 should be consdiered as digit numbers
567
+ # -1.00 or +1.00 should be considered as digit numbers
534
568
  float(value)
535
569
  except ValueError:
536
570
  # otherwise, they will be considered as formular injections
@@ -576,7 +610,7 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
576
610
  if not isinstance(value, str):
577
611
  return True
578
612
  try:
579
- # -1.00 or +1.00 should be consdiered as digit numbers
613
+ # -1.00 or +1.00 should be considered as digit numbers
580
614
  float(value)
581
615
  except ValueError:
582
616
  # otherwise, they will be considered as formular injections
@@ -607,8 +641,11 @@ def write_df_to_csv(data, filepath, mode="w", header=True, malicious_check=False
607
641
  def remove_path(path):
608
642
  if not os.path.exists(path):
609
643
  return
644
+ if os.path.islink(path):
645
+ logger.error(f"Failed to delete {path}, it is a symbolic link.")
646
+ raise RuntimeError("Delete file or directory failed.")
610
647
  try:
611
- if os.path.islink(path) or os.path.isfile(path):
648
+ if os.path.isfile(path):
612
649
  os.remove(path)
613
650
  else:
614
651
  shutil.rmtree(path)
@@ -617,7 +654,7 @@ def remove_path(path):
617
654
  raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err
618
655
  except Exception as e:
619
656
  logger.error("Failed to delete {}. Please check.".format(path))
620
- raise RuntimeError(f"Delete {path} failed.") from e
657
+ raise RuntimeError("Delete file or directory failed.") from e
621
658
 
622
659
 
623
660
  def get_json_contents(file_path):
@@ -651,46 +688,238 @@ def os_walk_for_files(path, depth):
651
688
  return res
652
689
 
653
690
 
654
- def check_crt_valid(pem_path, is_public_key=False):
691
+ def check_zip_file(zip_file_path):
692
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
693
+ total_size = 0
694
+ if len(zip_file.infolist()) > FileCheckConst.MAX_FILE_IN_ZIP_SIZE:
695
+ raise ValueError(f"Too many files in {os.path.basename(zip_file_path)}")
696
+ for file_info in zip_file.infolist():
697
+ if file_info.file_size > FileCheckConst.MAX_FILE_SIZE:
698
+ raise ValueError(f"File {file_info.filename} is too large to extract")
699
+
700
+ total_size += file_info.file_size
701
+ if total_size > FileCheckConst.MAX_ZIP_SIZE:
702
+ raise ValueError(f"Total extracted size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
703
+
704
+
705
+ def read_xlsx(file_path, sheet_name=None):
706
+ check_file_or_directory_path(file_path)
707
+ check_zip_file(file_path)
708
+ try:
709
+ if sheet_name:
710
+ result_df = pd.read_excel(file_path, keep_default_na=False, sheet_name=sheet_name)
711
+ else:
712
+ result_df = pd.read_excel(file_path, keep_default_na=False)
713
+ except Exception as e:
714
+ logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
715
+ raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
716
+ return result_df
717
+
718
+
719
+ def create_file_with_list(result_list, filepath):
720
+ check_path_before_create(filepath)
721
+ filepath = os.path.realpath(filepath)
722
+ try:
723
+ with FileOpen(filepath, 'w', encoding='utf-8') as file:
724
+ fcntl.flock(file, fcntl.LOCK_EX)
725
+ for item in result_list:
726
+ file.write(item + '\n')
727
+ fcntl.flock(file, fcntl.LOCK_UN)
728
+ except Exception as e:
729
+ logger.error(f'Save list to file "{os.path.basename(filepath)}" failed.')
730
+ raise RuntimeError(f"Save list to file {os.path.basename(filepath)} failed.") from e
731
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
732
+
733
+
734
+ def create_file_with_content(data, filepath):
735
+ check_path_before_create(filepath)
736
+ filepath = os.path.realpath(filepath)
737
+ try:
738
+ with FileOpen(filepath, 'w', encoding='utf-8') as file:
739
+ fcntl.flock(file, fcntl.LOCK_EX)
740
+ file.write(data)
741
+ fcntl.flock(file, fcntl.LOCK_UN)
742
+ except Exception as e:
743
+ logger.error(f'Save content to file "{os.path.basename(filepath)}" failed.')
744
+ raise RuntimeError(f"Save content to file {os.path.basename(filepath)} failed.") from e
745
+ change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
746
+
747
+
748
+ def add_file_to_zip(zip_file_path, file_path, arc_path=None):
655
749
  """
656
- Check the validity of the SSL certificate.
750
+ Add a file to a ZIP archive, if zip does not exist, create one.
657
751
 
658
- Load the SSL certificate from the specified path, parse and check its validity period.
659
- If the certificate is expired or invalid, raise a RuntimeError.
752
+ :param zip_file_path: Path to the ZIP archive
753
+ :param file_path: Path to the file to add
754
+ :param arc_path: Optional path inside the ZIP archive where the file should be added
755
+ """
756
+ check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
757
+ check_file_size(file_path, FileCheckConst.MAX_FILE_IN_ZIP_SIZE)
758
+ zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
759
+ if zip_size + os.path.getsize(file_path) > FileCheckConst.MAX_ZIP_SIZE:
760
+ raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
761
+ check_path_before_create(zip_file_path)
762
+ try:
763
+ proc_lock.acquire()
764
+ with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
765
+ zip_file.write(file_path, arc_path)
766
+ except Exception as e:
767
+ logger.error(f'add file to zip "{os.path.basename(zip_file_path)}" failed.')
768
+ raise RuntimeError(f"add file to zip {os.path.basename(zip_file_path)} failed.") from e
769
+ finally:
770
+ proc_lock.release()
771
+ change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
660
772
 
661
- Parameters:
662
- pem_path (str): The file path of the SSL certificate.
663
- is_public_key (bool): The file is public key or not.
664
773
 
665
- Raises:
666
- RuntimeError: If the SSL certificate is invalid or expired.
774
+ def create_file_in_zip(zip_file_path, file_name, content):
667
775
  """
668
- import OpenSSL
776
+ Create a file with content inside a ZIP archive.
777
+
778
+ :param zip_file_path: Path to the ZIP archive
779
+ :param file_name: Name of the file to create
780
+ :param content: Content to write to the file
781
+ """
782
+ check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
783
+ check_path_before_create(zip_file_path)
784
+ zip_size = os.path.getsize(zip_file_path) if os.path.exists(zip_file_path) else 0
785
+ if zip_size + sys.getsizeof(content) > FileCheckConst.MAX_ZIP_SIZE:
786
+ raise RuntimeError(f"ZIP file size exceeds the limit of {FileCheckConst.MAX_ZIP_SIZE} bytes")
669
787
  try:
670
- with FileOpen(pem_path, "r") as f:
671
- pem_data = f.read()
672
- if is_public_key:
673
- cert = OpenSSL.crypto.load_publickey(OpenSSL.crypto.FILETYPE_PEM, pem_data)
674
- else:
675
- cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_data)
676
- pem_start = parser.parse(cert.get_notBefore().decode("UTF-8"))
677
- pem_end = parser.parse(cert.get_notAfter().decode("UTF-8"))
678
- logger.info(f"The SSL certificate passes the verification and the validity period "
679
- f"starts from {pem_start} ends at {pem_end}.")
788
+ with open(zip_file_path, 'a+') as f: # 必须用 'a+' 模式才能 flock
789
+ # 2. 获取排他锁(阻塞直到成功)
790
+ fcntl.flock(f, fcntl.LOCK_EX) # LOCK_EX: 独占锁
791
+ with zipfile.ZipFile(zip_file_path, 'a') as zip_file:
792
+ zip_info = zipfile.ZipInfo(file_name)
793
+ zip_info.compress_type = zipfile.ZIP_DEFLATED
794
+ zip_file.writestr(zip_info, content)
795
+ fcntl.flock(f, fcntl.LOCK_UN)
680
796
  except Exception as e:
681
- logger.error("Failed to parse the SSL certificate. Check the certificate.")
682
- raise RuntimeError(f"The SSL certificate is invalid, {pem_path}") from e
797
+ logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
798
+ raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
799
+ change_mode(zip_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
683
800
 
684
- now_utc = datetime.now(tz=timezone.utc)
685
- if cert.has_expired() or not (pem_start <= now_utc <= pem_end):
686
- raise RuntimeError(f"The SSL certificate has expired and needs to be replaced, {pem_path}")
687
801
 
802
+ def extract_zip(zip_file_path, extract_dir):
803
+ """
804
+ Extract the contents of a ZIP archive to a specified directory.
688
805
 
689
- def read_xlsx(file_path):
690
- check_file_or_directory_path(file_path)
806
+ :param zip_file_path: Path to the ZIP archive
807
+ :param extract_dir: Directory to extract the contents to
808
+ """
809
+ check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
691
810
  try:
692
- result_df = pd.read_excel(file_path, keep_default_na=False)
811
+ proc_lock.acquire()
812
+ check_zip_file(zip_file_path)
693
813
  except Exception as e:
694
- logger.error(f"The xlsx file failed to load. Please check the path: {file_path}.")
695
- raise RuntimeError(f"Read xlsx file {file_path} failed.") from e
696
- return result_df
814
+ logger.error(f'Save content to file "{os.path.basename(zip_file_path)}" failed.')
815
+ raise RuntimeError(f"Save content to file {os.path.basename(zip_file_path)} failed.") from e
816
+ finally:
817
+ proc_lock.release()
818
+ with zipfile.ZipFile(zip_file_path, 'r') as zip_file:
819
+ zip_file.extractall(extract_dir)
820
+
821
+
822
+ def split_zip_file_path(zip_file_path):
823
+ check_file_suffix(zip_file_path, FileCheckConst.ZIP_SUFFIX)
824
+ zip_file_path = os.path.realpath(zip_file_path)
825
+ return os.path.dirname(zip_file_path), os.path.basename(zip_file_path)
826
+
827
+
828
+ def dedup_log(func_name, filter_name):
829
+ with SharedDict() as shared_dict:
830
+ exist_names = shared_dict.get(func_name, set())
831
+ if filter_name in exist_names:
832
+ return False
833
+ exist_names.add(filter_name)
834
+ shared_dict[func_name] = exist_names
835
+ return True
836
+
837
+
838
+ class SharedDict:
839
+ def __init__(self):
840
+ self._changed = False
841
+ self._dict = None
842
+ self._shm = None
843
+
844
+ def __enter__(self):
845
+ self._load_shared_memory()
846
+ return self
847
+
848
+ def __exit__(self, exc_type, exc_val, exc_tb):
849
+ try:
850
+ if self._changed:
851
+ data = pickle.dumps(self._dict)
852
+ global_lock.acquire()
853
+ try:
854
+ self._shm.buf[0:len(data)] = bytearray(data)
855
+ finally:
856
+ global_lock.release()
857
+ self._shm.close()
858
+ except FileNotFoundError:
859
+ name = self.get_shared_memory_name()
860
+ logger.debug(f'close shared memory {name} failed, shared memory has already been destroyed.')
861
+
862
+ def __setitem__(self, key, value):
863
+ self._dict[key] = value
864
+ self._changed = True
865
+
866
+ def __contains__(self, item):
867
+ return item in self._dict
868
+
869
+ @classmethod
870
+ def destroy_shared_memory(cls):
871
+ if is_main_process():
872
+ name = cls.get_shared_memory_name()
873
+ try:
874
+ shm = shared_memory.SharedMemory(create=False, name=name)
875
+ shm.close()
876
+ shm.unlink()
877
+ logger.debug(f'destroy shared memory, name: {name}')
878
+ except FileNotFoundError:
879
+ logger.debug(f'destroy shared memory {name} failed, shared memory has already been destroyed.')
880
+
881
+ @classmethod
882
+ def get_shared_memory_name(cls):
883
+ if is_main_process():
884
+ return f'shared_memory_{os.getpid()}'
885
+ return f'shared_memory_{os.getppid()}'
886
+
887
+ def get(self, key, default=None):
888
+ return self._dict.get(key, default)
889
+
890
+ def _load_shared_memory(self):
891
+ name = self.get_shared_memory_name()
892
+ try:
893
+ self._shm = shared_memory.SharedMemory(create=False, name=name)
894
+ except FileNotFoundError:
895
+ try:
896
+ # 共享内存空间增加至5M
897
+ self._shm = shared_memory.SharedMemory(create=True, name=name, size=1024 * 1024 * 5)
898
+ data = pickle.dumps({})
899
+ self._shm.buf[0:len(data)] = bytearray(data)
900
+ logger.debug(f'create shared memory, name: {name}')
901
+ except FileExistsError:
902
+ self._shm = shared_memory.SharedMemory(create=False, name=name)
903
+ self._safe_load()
904
+
905
+ def _safe_load(self):
906
+ with io.BytesIO(self._shm.buf[:]) as buff:
907
+ try:
908
+ self._dict = SafeUnpickler(buff).load()
909
+ except Exception as e:
910
+ logger.debug(f'shared dict is unreadable, reason: {e}, create new dict.')
911
+ self._dict = {}
912
+ self._shm.buf[:] = bytearray(b'\x00' * len(self._shm.buf)) # 清空内存
913
+ self._changed = True
914
+
915
+
916
+ class SafeUnpickler(pickle.Unpickler):
917
+ WHITELIST = {'builtins': {'str', 'bool', 'int', 'float', 'list', 'set', 'dict'}}
918
+
919
+ def find_class(self, module, name):
920
+ if module in self.WHITELIST and name in self.WHITELIST[module]:
921
+ return super().find_class(module, name)
922
+ raise pickle.PicklingError(f'Unpickling {module}.{name} is illegal!')
923
+
924
+
925
+ atexit.register(SharedDict.destroy_shared_memory)
@@ -0,0 +1,169 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import functools
16
+ from msprobe.core.common.const import Const
17
+ from msprobe.core.common.file_utils import check_file_or_directory_path
18
+ from msprobe.core.common.file_utils import save_npy
19
+
20
+
21
class FrameworkDescriptor:
    """Descriptor that lazily imports and caches the active framework module.

    First attribute access triggers ``owner.import_framework()``; every
    access returns the cached ``owner._framework``.
    """

    def __get__(self, instance, owner):
        framework = owner._framework
        if framework is None:
            # First access: let the owning class import torch/mindspore.
            owner.import_framework()
            framework = owner._framework
        return framework
26
+
27
+
28
class FmkAdp:
    """Framework adapter dispatching tensor/module operations to the active
    deep-learning framework (PyTorch or MindSpore).

    The framework module itself is imported lazily on first access of
    ``cls.framework`` (see ``FrameworkDescriptor``), so importing this module
    does not force torch/mindspore to load.
    """

    fmk = Const.PT_FRAMEWORK
    supported_fmk = [Const.PT_FRAMEWORK, Const.MS_FRAMEWORK]
    supported_dtype_list = ["bfloat16", "float16", "float32", "float64"]
    _framework = None  # cached framework module (torch or mindspore)
    framework = FrameworkDescriptor()

    @classmethod
    def import_framework(cls):
        """Import and cache the framework module selected by ``cls.fmk``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            import torch
            cls._framework = torch
        elif cls.fmk == Const.MS_FRAMEWORK:
            import mindspore
            cls._framework = mindspore
        else:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")

    @classmethod
    def set_fmk(cls, fmk=Const.PT_FRAMEWORK):
        """Select the active framework; raises when ``fmk`` is unsupported."""
        if fmk not in cls.supported_fmk:
            raise Exception(f"init framework adapter error, not in {cls.supported_fmk}")
        cls.fmk = fmk
        cls._framework = None  # reset so the next access re-imports the framework

    @classmethod
    def get_rank(cls):
        """Return the distributed rank of the current process."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.get_rank()
        return cls.framework.communication.get_rank()

    @classmethod
    def get_rank_id(cls):
        """Return the rank, or 0 when the distributed backend is uninitialized."""
        if cls.is_initialized():
            return cls.get_rank()
        return 0

    @classmethod
    def is_initialized(cls):
        """Return True when the distributed backend has been initialized."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return cls.framework.distributed.is_initialized()
        return cls.framework.communication.GlobalComm.INITED

    @classmethod
    def is_nn_module(cls, module):
        """Return True if ``module`` is a framework model container
        (torch.nn.Module / mindspore.nn.Cell)."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return isinstance(module, cls.framework.nn.Module)
        return isinstance(module, cls.framework.nn.Cell)

    @classmethod
    def is_tensor(cls, tensor):
        """Return True if ``tensor`` is a framework tensor.

        Both frameworks expose their tensor type as ``framework.Tensor``,
        so one check replaces the two identical branches the original had.
        """
        return isinstance(tensor, cls.framework.Tensor)

    @classmethod
    def process_tensor(cls, tensor, func):
        """Apply reduction ``func`` to ``tensor`` and return a Python float.

        PyTorch path: non-floating and float64 tensors are cast to float32
        first so the reduction is well-defined for the statistics below.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not tensor.is_floating_point() or tensor.dtype == cls.framework.float64:
                tensor = tensor.float()
            return float(func(tensor))
        return float(func(tensor).asnumpy())

    @classmethod
    def tensor_max(cls, tensor):
        """Return the maximum element of ``tensor`` as a float."""
        return cls.process_tensor(tensor, lambda x: x.max())

    @classmethod
    def tensor_min(cls, tensor):
        """Return the minimum element of ``tensor`` as a float."""
        return cls.process_tensor(tensor, lambda x: x.min())

    @classmethod
    def tensor_mean(cls, tensor):
        """Return the mean of ``tensor`` as a float."""
        return cls.process_tensor(tensor, lambda x: x.mean())

    @classmethod
    def tensor_norm(cls, tensor):
        """Return the norm of ``tensor`` as a float."""
        return cls.process_tensor(tensor, lambda x: x.norm())

    @classmethod
    def save_tensor(cls, tensor, filepath):
        """Convert ``tensor`` to a numpy array and save it via ``save_npy``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            # detach from autograd and move to host before converting
            tensor_npy = tensor.cpu().detach().float().numpy()
        else:
            tensor_npy = tensor.asnumpy()
        save_npy(tensor_npy, filepath)

    @classmethod
    def dtype(cls, dtype_str):
        """Map a dtype name to the framework dtype object.

        Raises for names outside ``supported_dtype_list``.
        """
        if dtype_str not in cls.supported_dtype_list:
            raise Exception(f"{dtype_str} is not supported by adapter, not in {cls.supported_dtype_list}")
        return getattr(cls.framework, dtype_str)

    @classmethod
    def named_parameters(cls, module):
        """Return an iterator of (name, parameter) pairs of ``module``."""
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            return module.named_parameters()
        if not isinstance(module, cls.framework.nn.Cell):
            raise Exception(f"{module} is not a mindspore.nn.Cell")
        return module.parameters_and_names()

    @classmethod
    def register_forward_pre_hook(cls, module, hook, with_kwargs=False):
        """Register ``hook`` to run before ``module``'s forward computation.

        PyTorch uses the native hook API. MindSpore Cells are handled by
        wrapping ``construct``. NOTE(review): in the MindSpore path the
        hook's return value is discarded, so hooks cannot rewrite the
        inputs there — confirm callers do not rely on that.
        """
        if cls.fmk == Const.PT_FRAMEWORK:
            if not isinstance(module, cls.framework.nn.Module):
                raise Exception(f"{module} is not a torch.nn.Module")
            module.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
        else:
            if not isinstance(module, cls.framework.nn.Cell):
                raise Exception(f"{module} is not a mindspore.nn.Cell")
            original_construct = module.construct

            @functools.wraps(original_construct)
            def new_construct(*args, **kwargs):
                if with_kwargs:
                    hook(module, args, kwargs)
                else:
                    hook(module, args)
                return original_construct(*args, **kwargs)

            module.construct = new_construct

    @classmethod
    def load_checkpoint(cls, path, to_cpu=True, weights_only=True):
        """Load a checkpoint from ``path``.

        :param path: checkpoint file path (validated before loading).
        :param to_cpu: PyTorch only — map storages to CPU when True.
        :param weights_only: PyTorch only — forwarded to ``torch.load``.
        :raises RuntimeError: when the PyTorch load fails.
        """
        check_file_or_directory_path(path)
        if cls.fmk == Const.PT_FRAMEWORK:
            try:
                if to_cpu:
                    return cls.framework.load(path, map_location=cls.framework.device("cpu"), weights_only=weights_only)
                return cls.framework.load(path, weights_only=weights_only)
            except Exception as e:
                raise RuntimeError(f"load pt file {path} failed: {e}") from e
        # Bug fix: the original referenced the bare name ``mindspore``, which
        # is never imported at module scope (NameError at runtime); route
        # through the lazily-imported framework module instead.
        return cls.framework.load_checkpoint(path)

    @classmethod
    def asnumpy(cls, tensor):
        """Return ``tensor`` as a float32 numpy array."""
        if cls.fmk == Const.PT_FRAMEWORK:
            return tensor.float().numpy()
        return tensor.float().asnumpy()
@@ -0,0 +1,86 @@
1
+ # Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import multiprocessing
17
+ from multiprocessing.shared_memory import SharedMemory
18
+ import random
19
+ import time
20
+ import atexit
21
+ import os
22
+
23
+ from msprobe.core.common.log import logger
24
+
25
+
26
def is_main_process():
    """Return True when executing in the main (parent) process."""
    current = multiprocessing.current_process()
    return current.name == 'MainProcess'
28
+
29
+
30
class GlobalLock:
    """Cross-process spin lock backed by a 1-byte shared-memory segment.

    Byte 0 of the segment is the lock state: 0 = free, 1 = held. The
    segment name is keyed on the main process's pid, so the main process
    and all of its children share one lock.

    NOTE(review): ``acquire`` uses a non-atomic check-then-set on the
    shared byte; the random backoff reduces but does not eliminate the
    chance of two processes both seeing 0 — confirm this is acceptable
    for the intended use.
    """

    def __init__(self):
        self.name = self.get_lock_name()
        try:
            self._shm = SharedMemory(create=False, name=self.name)
            # Sleep a random delay to avoid multiple processes grabbing the
            # lock at the same moment.
            time.sleep(random.randint(0, 500) / 10000)
        except FileNotFoundError:
            try:
                self._shm = SharedMemory(create=True, name=self.name, size=1)
                self._shm.buf[0] = 0  # initialize as "free"
                logger.debug(f'{self.name} is created.')
            except FileExistsError:
                # Lost the creation race: retry by re-running __init__, which
                # will now attach to the segment the winner created.
                self.__init__()

    @classmethod
    def get_lock_name(cls):
        """Return the segment name shared by the main process and children."""
        if is_main_process():
            return f'global_lock_{os.getpid()}'
        return f'global_lock_{os.getppid()}'

    @classmethod
    def is_lock_exist(cls):
        """Return True when this process tree's lock segment already exists."""
        try:
            SharedMemory(create=False, name=cls.get_lock_name()).close()
            return True
        except FileNotFoundError:
            return False

    def cleanup(self):
        """Close the local handle; only the main process unlinks the segment."""
        self._shm.close()
        if is_main_process():
            try:
                self._shm.unlink()
                logger.debug(f'{self.name} is unlinked.')
            except FileNotFoundError:
                logger.warning(f'{self.name} has already been unlinked.')

    def acquire(self, timeout=180):
        """
        acquire global lock, default timeout is 3 minutes.

        Spins until the lock byte reads 0, then sets it to 1. If the
        timeout elapses, the lock is seized unconditionally (see the final
        assignment) — a stuck holder is deliberately overridden.

        :param float timeout: timeout(seconds), default value is 180.
        """
        start = time.time()
        while time.time() - start < timeout:
            if self._shm.buf[0] == 0:
                self._shm.buf[0] = 1
                return
            # Spin with a random 1-50 ms backoff between polls.
            time.sleep(random.randint(10, 500) / 10000)
        self._shm.buf[0] = 1

    def release(self):
        """Mark the lock free. No ownership check: any process may release."""
        self._shm.buf[0] = 0
83
+
84
+
85
# Module-level singleton: every importer in the process tree shares this
# lock instance; its shared-memory backing is cleaned up at interpreter exit.
global_lock = GlobalLock()
atexit.register(global_lock.cleanup)