mindstudio-probe 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (228)
  1. mindstudio_probe-1.0.1.dist-info/LICENSE +201 -0
  2. mindstudio_probe-1.0.1.dist-info/METADATA +30 -0
  3. mindstudio_probe-1.0.1.dist-info/RECORD +228 -0
  4. mindstudio_probe-1.0.1.dist-info/WHEEL +5 -0
  5. mindstudio_probe-1.0.1.dist-info/entry_points.txt +2 -0
  6. mindstudio_probe-1.0.1.dist-info/top_level.txt +1 -0
  7. msprobe/README.md +182 -0
  8. msprobe/__init__.py +0 -0
  9. msprobe/config/README.md +397 -0
  10. msprobe/config/config.json +28 -0
  11. msprobe/config/img/free_benchmark.png +0 -0
  12. msprobe/core/common/const.py +241 -0
  13. msprobe/core/common/exceptions.py +88 -0
  14. msprobe/core/common/file_check.py +265 -0
  15. msprobe/core/common/log.py +55 -0
  16. msprobe/core/common/utils.py +516 -0
  17. msprobe/core/common_config.py +58 -0
  18. msprobe/core/data_dump/data_collector.py +140 -0
  19. msprobe/core/data_dump/data_processor/base.py +245 -0
  20. msprobe/core/data_dump/data_processor/factory.py +61 -0
  21. msprobe/core/data_dump/data_processor/pytorch_processor.py +346 -0
  22. msprobe/core/data_dump/json_writer.py +116 -0
  23. msprobe/core/data_dump/scope.py +178 -0
  24. msprobe/mindspore/__init__.py +1 -0
  25. msprobe/mindspore/debugger/__init__.py +0 -0
  26. msprobe/mindspore/debugger/debugger_config.py +51 -0
  27. msprobe/mindspore/debugger/precision_debugger.py +32 -0
  28. msprobe/mindspore/doc/dump.md +65 -0
  29. msprobe/mindspore/dump/__init__.py +0 -0
  30. msprobe/mindspore/dump/api_kbk_dump.py +55 -0
  31. msprobe/mindspore/dump/dump_tool_factory.py +38 -0
  32. msprobe/mindspore/dump/kernel_graph_dump.py +60 -0
  33. msprobe/mindspore/ms_config.py +78 -0
  34. msprobe/mindspore/overflow_check/__init__.py +0 -0
  35. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +45 -0
  36. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +32 -0
  37. msprobe/mindspore/task_handler_factory.py +21 -0
  38. msprobe/msprobe.py +67 -0
  39. msprobe/pytorch/__init__.py +4 -0
  40. msprobe/pytorch/advisor/advisor.py +124 -0
  41. msprobe/pytorch/advisor/advisor_const.py +59 -0
  42. msprobe/pytorch/advisor/advisor_result.py +58 -0
  43. msprobe/pytorch/api_accuracy_checker/.keep +0 -0
  44. msprobe/pytorch/api_accuracy_checker/__init__.py +0 -0
  45. msprobe/pytorch/api_accuracy_checker/common/.keep +0 -0
  46. msprobe/pytorch/api_accuracy_checker/common/__init__.py +0 -0
  47. msprobe/pytorch/api_accuracy_checker/common/config.py +50 -0
  48. msprobe/pytorch/api_accuracy_checker/common/utils.py +224 -0
  49. msprobe/pytorch/api_accuracy_checker/compare/__init__.py +0 -0
  50. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +216 -0
  51. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +545 -0
  52. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +133 -0
  53. msprobe/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml +390 -0
  54. msprobe/pytorch/api_accuracy_checker/compare/compare.py +345 -0
  55. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +74 -0
  56. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +249 -0
  57. msprobe/pytorch/api_accuracy_checker/config.yaml +4 -0
  58. msprobe/pytorch/api_accuracy_checker/run_ut/.keep +0 -0
  59. msprobe/pytorch/api_accuracy_checker/run_ut/__init__.py +0 -0
  60. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +328 -0
  61. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +203 -0
  62. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +127 -0
  63. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +493 -0
  64. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +7 -0
  65. msprobe/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json +5 -0
  66. msprobe/pytorch/common/__init__.py +2 -0
  67. msprobe/pytorch/common/compare_script.template +14 -0
  68. msprobe/pytorch/common/log.py +32 -0
  69. msprobe/pytorch/common/parse_json.py +37 -0
  70. msprobe/pytorch/common/utils.py +224 -0
  71. msprobe/pytorch/compare/acc_compare.py +1024 -0
  72. msprobe/pytorch/compare/distributed_compare.py +111 -0
  73. msprobe/pytorch/compare/highlight.py +100 -0
  74. msprobe/pytorch/compare/mapping.yaml +607 -0
  75. msprobe/pytorch/compare/match.py +36 -0
  76. msprobe/pytorch/compare/npy_compare.py +244 -0
  77. msprobe/pytorch/debugger/__init__.py +0 -0
  78. msprobe/pytorch/debugger/debugger_config.py +86 -0
  79. msprobe/pytorch/debugger/precision_debugger.py +95 -0
  80. msprobe/pytorch/doc/FAQ.md +193 -0
  81. msprobe/pytorch/doc/api_accuracy_checker.md +269 -0
  82. msprobe/pytorch/doc/atat/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/205/342/225/226/320/265/321/205/320/225/342/225/226/321/206/320/245/342/226/221/321/206/320/235/320/276dump/321/206/320/260/320/227/321/205/320/227/320/226/321/206/320/220/320/267/321/210/320/223/342/225/234/321/205/320/257/342/225/221/321/207/342/225/221/342/224/220/321/206/320/232/320/265/321/205/320/241/320/232.md +182 -0
  83. msprobe/pytorch/doc/dump.md +207 -0
  84. msprobe/pytorch/doc/img/BLOOM-7B_1.png +0 -0
  85. msprobe/pytorch/doc/img/BLOOM-7B_2.png +0 -0
  86. msprobe/pytorch/doc/img/BLOOM-7B_3.png +0 -0
  87. msprobe/pytorch/doc/img/BLOOM-7B_4.png +0 -0
  88. msprobe/pytorch/doc/img/GPT-3_1.png +0 -0
  89. msprobe/pytorch/doc/img/GPT-3_2.png +0 -0
  90. msprobe/pytorch/doc/img/GPT-3_3.png +0 -0
  91. msprobe/pytorch/doc/img/GPT-3_4.png +0 -0
  92. msprobe/pytorch/doc/img/GPT-3_5.png +0 -0
  93. msprobe/pytorch/doc/img/GPT-3_6.png +0 -0
  94. msprobe/pytorch/doc/img/GPT-3_7.png +0 -0
  95. msprobe/pytorch/doc/img/GPT-3_8.png +0 -0
  96. msprobe/pytorch/doc/img/YOLOV5S_1.png +0 -0
  97. msprobe/pytorch/doc/img/YOLOV5S_2.png +0 -0
  98. msprobe/pytorch/doc/img/accuracy_checking_details.png +0 -0
  99. msprobe/pytorch/doc/img/accuracy_checking_result.png +0 -0
  100. msprobe/pytorch/doc/img/api_precision_compare_details.png +0 -0
  101. msprobe/pytorch/doc/img/api_precision_compare_result.png +0 -0
  102. msprobe/pytorch/doc/img/auto_analyze_log.png +0 -0
  103. msprobe/pytorch/doc/img/compare_result_pkl.png +0 -0
  104. msprobe/pytorch/doc/img/compare_result_pkl_md5.png.png +0 -0
  105. msprobe/pytorch/doc/img/cpu_info.png +0 -0
  106. msprobe/pytorch/doc/img/module_compare.png +0 -0
  107. msprobe/pytorch/doc/parse_tool.md +286 -0
  108. msprobe/pytorch/doc/ptdbg_ascend_compare.md +176 -0
  109. msprobe/pytorch/doc/ptdbg_ascend_overview.md +68 -0
  110. msprobe/pytorch/doc/ptdbg_ascend_quickstart.md +381 -0
  111. msprobe/pytorch/doc/run_overflow_check.md +25 -0
  112. msprobe/pytorch/doc//321/205/320/254/320/270/321/207/342/225/221/342/224/220/321/207/342/226/223/342/225/233/321/205/342/225/221/320/266/321/206/320/277/320/244/321/205/320/277/342/225/243.md +90 -0
  113. msprobe/pytorch/free_benchmark/__init__.py +8 -0
  114. msprobe/pytorch/free_benchmark/common/__init__.py +0 -0
  115. msprobe/pytorch/free_benchmark/common/constant.py +67 -0
  116. msprobe/pytorch/free_benchmark/common/counter.py +72 -0
  117. msprobe/pytorch/free_benchmark/common/enums.py +37 -0
  118. msprobe/pytorch/free_benchmark/common/params.py +129 -0
  119. msprobe/pytorch/free_benchmark/common/utils.py +98 -0
  120. msprobe/pytorch/free_benchmark/compare/grad_saver.py +183 -0
  121. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +104 -0
  122. msprobe/pytorch/free_benchmark/main.py +102 -0
  123. msprobe/pytorch/free_benchmark/perturbed_layers/__init__.py +0 -0
  124. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +13 -0
  125. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +41 -0
  126. msprobe/pytorch/free_benchmark/perturbed_layers/npu/__init__.py +0 -0
  127. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +90 -0
  128. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +104 -0
  129. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +63 -0
  130. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +68 -0
  131. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +28 -0
  132. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +45 -0
  133. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +19 -0
  134. msprobe/pytorch/free_benchmark/result_handlers/__init__.py +0 -0
  135. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +203 -0
  136. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +39 -0
  137. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +24 -0
  138. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +31 -0
  139. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +170 -0
  140. msprobe/pytorch/functional/__init__.py +0 -0
  141. msprobe/pytorch/functional/data_processor.py +0 -0
  142. msprobe/pytorch/functional/dump_module.py +39 -0
  143. msprobe/pytorch/hook_module/__init__.py +1 -0
  144. msprobe/pytorch/hook_module/api_registry.py +161 -0
  145. msprobe/pytorch/hook_module/hook_module.py +109 -0
  146. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1876 -0
  147. msprobe/pytorch/hook_module/utils.py +29 -0
  148. msprobe/pytorch/hook_module/wrap_aten.py +100 -0
  149. msprobe/pytorch/hook_module/wrap_distributed.py +75 -0
  150. msprobe/pytorch/hook_module/wrap_functional.py +108 -0
  151. msprobe/pytorch/hook_module/wrap_npu_custom.py +73 -0
  152. msprobe/pytorch/hook_module/wrap_tensor.py +72 -0
  153. msprobe/pytorch/hook_module/wrap_torch.py +88 -0
  154. msprobe/pytorch/hook_module/wrap_vf.py +64 -0
  155. msprobe/pytorch/module_processer.py +98 -0
  156. msprobe/pytorch/online_dispatch/__init__.py +20 -0
  157. msprobe/pytorch/online_dispatch/compare.py +236 -0
  158. msprobe/pytorch/online_dispatch/dispatch.py +274 -0
  159. msprobe/pytorch/online_dispatch/dump_compare.py +186 -0
  160. msprobe/pytorch/online_dispatch/single_compare.py +391 -0
  161. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +50 -0
  162. msprobe/pytorch/online_dispatch/utils.py +187 -0
  163. msprobe/pytorch/parse.py +4 -0
  164. msprobe/pytorch/parse_tool/__init__.py +0 -0
  165. msprobe/pytorch/parse_tool/cli.py +32 -0
  166. msprobe/pytorch/parse_tool/lib/__init__.py +0 -0
  167. msprobe/pytorch/parse_tool/lib/compare.py +259 -0
  168. msprobe/pytorch/parse_tool/lib/config.py +51 -0
  169. msprobe/pytorch/parse_tool/lib/file_desc.py +31 -0
  170. msprobe/pytorch/parse_tool/lib/interactive_cli.py +102 -0
  171. msprobe/pytorch/parse_tool/lib/parse_exception.py +54 -0
  172. msprobe/pytorch/parse_tool/lib/parse_tool.py +158 -0
  173. msprobe/pytorch/parse_tool/lib/utils.py +367 -0
  174. msprobe/pytorch/parse_tool/lib/visualization.py +90 -0
  175. msprobe/pytorch/pt_config.py +93 -0
  176. msprobe/pytorch/service.py +167 -0
  177. msprobe/test/core_ut/common/test_utils.py +345 -0
  178. msprobe/test/core_ut/data_dump/test_data_collector.py +47 -0
  179. msprobe/test/core_ut/data_dump/test_json_writer.py +183 -0
  180. msprobe/test/core_ut/data_dump/test_scope.py +151 -0
  181. msprobe/test/core_ut/test_common_config.py +152 -0
  182. msprobe/test/core_ut/test_file_check.py +218 -0
  183. msprobe/test/core_ut/test_log.py +109 -0
  184. msprobe/test/mindspore_ut/test_api_kbk_dump.py +51 -0
  185. msprobe/test/mindspore_ut/test_debugger_config.py +42 -0
  186. msprobe/test/mindspore_ut/test_dump_tool_factory.py +51 -0
  187. msprobe/test/mindspore_ut/test_kernel_graph_dump.py +66 -0
  188. msprobe/test/mindspore_ut/test_kernel_graph_overflow_check.py +63 -0
  189. msprobe/test/mindspore_ut/test_ms_config.py +69 -0
  190. msprobe/test/mindspore_ut/test_overflow_check_tool_factory.py +51 -0
  191. msprobe/test/mindspore_ut/test_precision_debugger.py +56 -0
  192. msprobe/test/mindspore_ut/test_task_handler_factory.py +58 -0
  193. msprobe/test/pytorch_ut/advisor/test_advisor.py +83 -0
  194. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +108 -0
  195. msprobe/test/pytorch_ut/api_accuracy_checker/common/test_config.py +39 -0
  196. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_algorithm.py +112 -0
  197. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +77 -0
  198. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare.py +125 -0
  199. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_column.py +10 -0
  200. msprobe/test/pytorch_ut/api_accuracy_checker/compare/test_compare_utils.py +43 -0
  201. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/dump.json +179 -0
  202. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/forward.json +63 -0
  203. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_data_generate.py +99 -0
  204. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_multi_run_ut.py +115 -0
  205. msprobe/test/pytorch_ut/api_accuracy_checker/run_ut/test_run_ut.py +72 -0
  206. msprobe/test/pytorch_ut/compare/test_acc_compare.py +17 -0
  207. msprobe/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +105 -0
  208. msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +121 -0
  209. msprobe/test/pytorch_ut/free_benchmark/test_main.py +101 -0
  210. msprobe/test/pytorch_ut/functional/test_dump_module.py +15 -0
  211. msprobe/test/pytorch_ut/hook_module/test_api_registry.py +130 -0
  212. msprobe/test/pytorch_ut/hook_module/test_hook_module.py +42 -0
  213. msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +65 -0
  214. msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +35 -0
  215. msprobe/test/pytorch_ut/hook_module/test_wrap_functional.py +20 -0
  216. msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +35 -0
  217. msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +43 -0
  218. msprobe/test/pytorch_ut/hook_module/test_wrap_vf.py +11 -0
  219. msprobe/test/pytorch_ut/test_pt_config.py +69 -0
  220. msprobe/test/pytorch_ut/test_service.py +59 -0
  221. msprobe/test/resources/advisor.txt +3 -0
  222. msprobe/test/resources/compare_result_20230703104808.csv +9 -0
  223. msprobe/test/resources/compare_result_without_accuracy.csv +9 -0
  224. msprobe/test/resources/config.yaml +3 -0
  225. msprobe/test/resources/npu_test.pkl +8 -0
  226. msprobe/test/run_test.sh +30 -0
  227. msprobe/test/run_ut.py +58 -0
  228. msprobe/test/test_module_processer.py +64 -0
@@ -0,0 +1,346 @@
+ import os
+ import zlib
+ from dataclasses import asdict
+ from typing import List
+
+ import numpy as np
+ import torch
+ from msprobe.core.common.exceptions import MsaccException
+ from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.const import Const, OverflowConst, FileCheckConst
+ from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
+     ModuleForwardInputsOutputs, TensorStatInfo
+ from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
+
+ try:
+     import torch_npu
+ except ImportError:
+     pass
+
+
+ class PytorchDataProcessor(BaseDataProcessor):
+     pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor)
+
+     def __init__(self, config, data_writer):
+         super().__init__(config, data_writer)
+         self.torch_object_key = {
+             "device": self.analyze_device_in_kwargs,
+             "dtype": self.analyze_dtype_in_kwargs
+         }
+
+     @staticmethod
+     def get_md5_for_tensor(x):
+         if x.dtype == torch.bfloat16:
+             x = x.float()
+         tensor_bytes = x.cpu().detach().numpy().tobytes()
+         crc32_hash = zlib.crc32(tensor_bytes)
+         return f"{crc32_hash:08x}"
+
+     @staticmethod
+     def analyze_device_in_kwargs(element):
+         single_arg = {}
+         single_arg.update({'type': "torch.device"})
+         if not isinstance(element, str):
+             if hasattr(element, "index"):
+                 device_value = element.type + ":" + str(element.index)
+             else:
+                 device_value = element.type
+             single_arg.update({"value": device_value})
+         else:
+             single_arg.update({"value": element})
+         return single_arg
+
+     @staticmethod
+     def analyze_dtype_in_kwargs(element):
+         return {"type": "torch.dtype", "value": str(element)}
+
+     @staticmethod
+     def get_stat_info(data):
+         tensor_stat = TensorStatInfo()
+         if data.is_meta:
+             return tensor_stat
+         data_clone = data.detach()
+         if data_clone.numel() == 0:
+             return tensor_stat
+         elif data_clone.dtype == torch.bool:
+             tensor_stat.max = True in data_clone
+             tensor_stat.min = False not in data_clone
+         elif not data_clone.shape:
+             tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item()
+         else:
+             if not data_clone.is_floating_point() or data_clone.dtype == torch.float64:
+                 data_clone = data_clone.float()
+             tensor_stat.max = torch._C._VariableFunctionsClass.max(data_clone).item()
+             tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item()
+             tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
+             tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item()
+         return tensor_stat
+
+     @staticmethod
+     def _analyze_torch_size(arg):
+         return {"type": "torch.Size", "value": list(arg)}
+
+     @classmethod
+     def get_special_types(cls):
+         return super().get_special_types() + cls.pytorch_special_type
+
+     def analyze_single_element(self, element, suffix_stack):
+         if suffix_stack and suffix_stack[-1] in self.torch_object_key:
+             return self.torch_object_key[suffix_stack[-1]](element)
+         if isinstance(element, torch.Size):
+             return self._analyze_torch_size(element)
+         converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
+         if converted_numpy is not element:
+             return self._analyze_numpy(converted_numpy, numpy_type)
+         if isinstance(element, torch.Tensor):
+             return self._analyze_tensor(element, Const.SEP.join(suffix_stack))
+         if isinstance(element, (bool, int, float, str, slice)):
+             return self._analyze_builtin(element)
+         return None
+
+     def analyze_element(self, element):
+         return self.recursive_apply_transform(element, self.analyze_single_element)
+
+     def _analyze_tensor(self, tensor, suffix):
+         tensor_stat = self.get_stat_info(tensor)
+         tensor_json = {}
+         tensor_json.update({'type': 'torch.Tensor'})
+         tensor_json.update({'dtype': str(tensor.dtype)})
+         tensor_json.update({"shape": tensor.shape})
+         tensor_json.update({"Max": tensor_stat.max})
+         tensor_json.update({"Min": tensor_stat.min})
+         tensor_json.update({"Mean": tensor_stat.mean})
+         tensor_json.update({"Norm": tensor_stat.norm})
+         tensor_json.update({"requires_grad": tensor.requires_grad})
+         if self.config.summary_mode == "md5":
+             tensor_md5 = self.get_md5_for_tensor(tensor)
+             tensor_json.update({"md5": tensor_md5})
+         return tensor_json
+
+
+ class StatisticsDataProcessor(PytorchDataProcessor):
+     pass
+
+
+ class TensorDataProcessor(PytorchDataProcessor):
+     def _analyze_tensor(self, tensor, suffix):
+         dump_data_name, file_path = self.get_save_file_path(suffix)
+         if not path_len_exceeds_limit(file_path):
+             torch.save(tensor, file_path)
+             change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+         else:
+             logger.warning(f'The file path {file_path} length exceeds limit.')
+         single_arg = super()._analyze_tensor(tensor, suffix)
+         single_arg.update({"data_name": dump_data_name})
+         return single_arg
+
+
+ class OverflowCheckDataProcessor(PytorchDataProcessor):
+     __slots__ = ["cached_tensors_and_file_paths"]
+
+     def __init__(self, config, data_writer):
+         super().__init__(config, data_writer)
+         self.cached_tensors_and_file_paths = {}
+         self.real_overflow_dump_times = 0
+         self.overflow_nums = config.overflow_num
+         self.bits_for_overflow = 8
+
+     @staticmethod
+     def overflow_debug_mode_enable():
+         overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE)
+         return overflow_mode == Const.ENV_ENABLE
+
+     @staticmethod
+     def handle_tensor_extremum_nan_inf(data_clone, operator):
+         data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
+         if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
+             return float('nan')
+         finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
+         if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
+             finite_values = data_clone[finite_mask]
+             return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
+                 torch._C._VariableFunctionsClass.min(finite_values).item()
+         else:
+             data_no_nan = data_clone[~data_nan]
+             return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
+                 torch._C._VariableFunctionsClass.min(data_no_nan).item()
+
+     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         self.has_overflow = False
+         api_info_struct = super().analyze_forward(name, module, module_input_output)
+         self.maybe_save_overflow_data_and_check_overflow_times()
+         return api_info_struct if self.has_overflow else None
+
+     def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+         self.has_overflow = False
+         api_info_struct = super().analyze_backward(name, module, module_input_output)
+         self.maybe_save_overflow_data_and_check_overflow_times()
+         return api_info_struct if self.has_overflow else None
+
+     def maybe_save_overflow_data_and_check_overflow_times(self):
+         if self.has_overflow:
+             for file_path, tensor in self.cached_tensors_and_file_paths.items():
+                 torch.save(tensor, file_path)
+                 change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+             self.inc_and_check_overflow_times()
+         self.cached_tensors_and_file_paths = {}
+
+     def inc_and_check_overflow_times(self):
+         self.real_overflow_dump_times += 1
+         if self.overflow_nums == -1:
+             return
+         if self.real_overflow_dump_times >= self.overflow_nums:
+             raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times))
+
+     def check_overflow_npu(self):
+         if self.overflow_debug_mode_enable():
+             float_status = torch.zeros(self.bits_for_overflow).npu()
+             result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
+             if result.cpu()[0] != 0:
+                 return True
+             else:
+                 return False
+         else:
+             return torch_npu._C._check_overflow_npu()
+
+     def clear_overflow_npu(self):
+         if self.overflow_debug_mode_enable():
+             float_status = torch.zeros(self.bits_for_overflow).npu()
+             torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE)
+         else:
+             torch_npu._C._clear_overflow_npu()
+
+     def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
+         data_clone = tensor.detach()
+         if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan():
+             if tensor_json['Max'] is None:
+                 return
+             if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
+                 tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max")
+                 self.has_overflow = True
+             if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
+                 tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min")
+                 self.has_overflow = True
+         else:
+             self.has_overflow = self.check_overflow_npu()
+             if self.has_overflow:
+                 self.clear_overflow_npu()
+
+     def _analyze_tensor(self, tensor, suffix):
+         dump_data_name, file_path = self.get_save_file_path(suffix)
+         if not path_len_exceeds_limit(file_path):
+             self.cached_tensors_and_file_paths.update({file_path: tensor})
+         else:
+             logger.warning(f'The file path {file_path} length exceeds limit.')
+         single_arg = super()._analyze_tensor(tensor, suffix)
+         self._analyze_maybe_overflow_tensor(single_arg, tensor)
+         single_arg.update({"data_name": dump_data_name})
+         return single_arg
+
+
+ class FreeBenchmarkDataProcessor(PytorchDataProcessor):
+
+     def __init__(self, config, data_writer):
+         super().__init__(config, data_writer)
+         self.checker = FreeBenchmarkCheck(config=config)
+         self._return_forward_new_output = None
+         self._forward_new_output = None
+
+     def update_iter(self, current_iter):
+         super().update_iter(current_iter)
+         self.checker.update_iter(current_iter)
+
+     def update_unequal_rows(self, unequal_rows: List[UnequalRow]):
+         if not unequal_rows:
+             return
+         for row in unequal_rows:
+             data_dict = asdict(row)
+             self.data_writer.write_data_to_csv(
+                 data_dict.values(),
+                 data_dict.keys(),
+                 self.data_writer.free_benchmark_file_path
+             )
+         return
+
+     def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         self.checker.pre_forward(name, module, self, module_input_output.args, module_input_output.kwargs)
+
+     def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+         new_output, unequal_rows = self.checker.forward(
+             name,
+             module,
+             module_input_output.args,
+             module_input_output.kwargs,
+             module_input_output.output,
+         )
+         self.update_unequal_rows(unequal_rows)
+         if self.checker.if_fix():
+             self._return_forward_new_output = True
+             self._forward_new_output = new_output
+
+     def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs):
+         self.checker.backward(name, module, module_input_output.grad_output)
+
+
+ class KernelDumpDataProcessor(PytorchDataProcessor):
+     forward_init_status = False
+     multi_output_apis = ["_sort_", "npu_flash_attention"]
+
+     def __init__(self, config, data_writer):
+         super().__init__(config, data_writer)
+
+     def analyze_forward(self, name, module, module_input_output):
+         if self.config.is_forward_acl_dump:
+             self.forward_acl_dump(name, module, module_input_output)
+         else:
+             self.dump_mode_backward_acl_dump(name, module, module_input_output)
+
+     def forward_acl_dump(self, name, module, module_input_output):
+         if not KernelDumpDataProcessor.forward_init_status:
+             KernelDumpDataProcessor.forward_init_status = True
+             torch_npu.npu.synchronize()
+             torch_npu.npu.init_dump()
+             torch_npu.npu.set_dump(self.config.acl_config)
+             torch_npu.npu.synchronize()
+             if self.op_need_trigger(name):
+                 module.forward(*module_input_output.args, **module_input_output.kwargs).cpu()
+             else:
+                 module.forward(*module_input_output.args, **module_input_output.kwargs)
+             torch_npu.npu.synchronize()
+             torch_npu.npu.finalize_dump()
+             torch_npu.npu.synchronize()
+         KernelDumpDataProcessor.forward_init_status = False
+         logger.info("Dump %s op file." % name)
+
+     def acl_backward_dump_status(self, output, grad, module_name):
+         if isinstance(output, torch.Tensor):
+             output.backward(grad, retain_graph=True)
+             return True
+
+         for api_name in KernelDumpDataProcessor.multi_output_apis:
+             if api_name in module_name:
+                 output[0].backward(grad, retain_graph=True)
+                 return True
+         return False
+
+     def dump_mode_backward_acl_dump(self, name, module, module_input_output):
+         grad_path = self.config.backward_input.get(name)
+         if not KernelDumpDataProcessor.forward_init_status:
+             KernelDumpDataProcessor.forward_init_status = True
+             output = module.forward(*module_input_output.args, **module_input_output.kwargs)
+             grad = torch.load(grad_path).to("npu").requires_grad_()
+             torch_npu.npu.init_dump()
+             torch_npu.npu.set_dump(self.config.acl_config)
+             torch_npu.npu.synchronize()
+             if not self.acl_backward_dump_status(output, grad, name):
+                 logger.warning("The output of {} is not of tensor type and cannot be automatically derived. "
+                                "You can manually construct a single API backward case for ACL dump.".format(name))
+             torch_npu.npu.synchronize()
+             torch_npu.npu.finalize_dump()
+         KernelDumpDataProcessor.forward_init_status = False
+         logger.info("Dump %s op file." % name)
+
+     def op_need_trigger(self, module_name):
+         return 'Tensor.__getitem__.' in module_name
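By its line count (+346), the hunk above corresponds to msprobe/core/data_dump/data_processor/pytorch_processor.py from the file list. As a minimal sketch (not part of the package), the two static helpers can be exercised on their own; only torch is required, the sample tensor is illustrative, and the import path is inferred from the RECORD layout:

    import torch
    from msprobe.core.data_dump.data_processor.pytorch_processor import PytorchDataProcessor

    t = torch.randn(4, 4)
    stat = PytorchDataProcessor.get_stat_info(t)        # TensorStatInfo with max/min/mean/norm
    print(stat.max, stat.min, stat.mean, stat.norm)
    print(PytorchDataProcessor.get_md5_for_tensor(t))   # CRC32 of the raw tensor bytes as 8 hex chars

Note that despite the "md5" naming, the summary value written in md5 mode is a CRC32 checksum.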
@@ -0,0 +1,116 @@
+ import os
+ import csv
+ import fcntl
+ import json
+ from pathlib import Path
+
+ from msprobe.core.common.file_check import change_mode
+ from msprobe.core.common.log import logger
+ from msprobe.core.common.const import Const, FileCheckConst
+
+
+ class DataWriter:
+
+     def __init__(self, init_json=None) -> None:
+         self.dump_count = 0
+         self.init_json = init_json
+         self.dump_file_path = None  # os.path.join(dump_dir, DataWriter.dump_json_name)
+         self.stack_file_path = None  # os.path.join(dump_dir, DataWriter.stack_json_name)
+         self.construct_file_path = None  # os.path.join(dump_dir, DataWriter.construct_json_name)
+         self.free_benchmark_file_path = None
+         self.dump_tensor_data_dir = None
+         self.buffer_size = 1000
+         self.cache_data = {Const.DATA: {}}
+         self.cache_stack = {}
+         self.cache_construct = {}
+
+     @staticmethod
+     def write_data_to_csv(result: list, result_header: tuple, file_path: str):
+         if not result:
+             return
+         is_exists = os.path.exists(file_path)
+         append = "a+" if is_exists else "w+"
+         with os.fdopen(
+             os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline=""
+         ) as csv_file:
+             spawn_writer = csv.writer(csv_file)
+             if not is_exists:
+                 spawn_writer.writerow(result_header)
+             spawn_writer.writerows([result,])
+
+     def initialize_json_file(self, **kwargs):
+         kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}})
+         with os.fdopen(
+             os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w'
+         ) as f:
+             json.dump(kwargs, f)
+
+         if os.path.exists(self.stack_file_path):
+             os.remove(self.stack_file_path)
+         Path(self.stack_file_path).touch()
+         change_mode(self.stack_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+         if os.path.exists(self.construct_file_path):
+             os.remove(self.construct_file_path)
+         Path(self.construct_file_path).touch()
+         change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+     def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir,
+                           free_benchmark_file_path):
+         self.dump_file_path = dump_file_path
+         self.stack_file_path = stack_file_path
+         self.construct_file_path = construct_file_path
+         self.dump_tensor_data_dir = dump_data_dir
+         self.free_benchmark_file_path = free_benchmark_file_path
+
+     def update_data(self, new_data):
+         key = next(iter(new_data.keys()))  # assert len(new_data.keys()) == 1
+         if key in self.cache_data[Const.DATA]:
+             self.cache_data[Const.DATA][key].update(new_data[key])
+         else:
+             self.cache_data[Const.DATA].update(new_data)
+
+     def flush_data_when_buffer_is_full(self):
+         if len(self.cache_data[Const.DATA]) >= self.buffer_size:
+             self.write_data_json(self.dump_file_path)
+
+     def update_stack(self, new_data):
+         self.cache_stack.update(new_data)
+
+     def update_construct(self, new_data):
+         self.cache_construct.update(new_data)
+
+     def write_data_json(self, file_path):
+         logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
+         if Path(file_path).exists() and os.path.getsize(file_path) > 0:
+             with open(file_path, "r+") as f:
+                 fcntl.flock(f, fcntl.LOCK_EX)
+                 data_to_write = json.load(f)
+                 fcntl.flock(f, fcntl.LOCK_UN)
+         else:
+             self.init_json['data_path'] = self.dump_tensor_data_dir
+             data_to_write = self.init_json
+         data_to_write[Const.DATA].update(self.cache_data[Const.DATA])
+         with open(file_path, 'w+') as f:
+             fcntl.flock(f, fcntl.LOCK_EX)
+             json.dump(data_to_write, f, indent=1)
+             fcntl.flock(f, fcntl.LOCK_UN)
+
+         self.cache_data[Const.DATA].clear()
+
+     def write_stack_info_json(self, file_path):
+         with open(file_path, 'w+') as f:
+             fcntl.flock(f, fcntl.LOCK_EX)
+             json.dump(self.cache_stack, f, indent=1)
+             fcntl.flock(f, fcntl.LOCK_UN)
+
+     def write_construct_info_json(self, file_path):
+         with open(file_path, 'w+') as f:
+             fcntl.flock(f, fcntl.LOCK_EX)
+             json.dump(self.cache_construct, f, indent=1)
+             fcntl.flock(f, fcntl.LOCK_UN)
+
+     def write_json(self):
+         self.write_data_json(self.dump_file_path)
+         self.write_stack_info_json(self.stack_file_path)
+         self.write_construct_info_json(self.construct_file_path)
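By its line count (+116), the hunk above corresponds to msprobe/core/data_dump/json_writer.py. A minimal sketch (not part of the package) of the static CSV helper; the header row is written only when the file does not exist yet, and later calls append a row each. The header names and file path below are hypothetical:

    from msprobe.core.data_dump.json_writer import DataWriter

    header = ("api_name", "max_diff", "min_diff")
    DataWriter.write_data_to_csv(["Tensor.add.0.forward", 0.01, 0.0], header, "/tmp/free_benchmark.csv")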
@@ -0,0 +1,178 @@
+ from abc import ABC, abstractmethod
+ from msprobe.core.common.exceptions import ScopeException
+ from msprobe.core.common.const import Const
+
+
+ def build_scope(scope_class, scope=None, api_list=None):
+     if not scope and not api_list:
+         return None
+     if scope is None:
+         scope = []
+     if api_list is None:
+         api_list = []
+     if scope_class:
+         return scope_class(scope, api_list)
+     return build_range_scope_according_to_scope_name(scope, api_list)
+
+
+ def build_range_scope_according_to_scope_name(scope, api_list):
+     api_range_scope = APIRangeScope(scope, api_list)
+     module_range_scope = ModuleRangeScope(scope, api_list)
+     if not scope:  # if no scope is configured, either scope type behaves the same
+         return api_range_scope
+     if api_range_scope.is_valid and module_range_scope.is_valid:
+         raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
+     elif api_range_scope.is_valid:
+         return api_range_scope
+     elif module_range_scope.is_valid:
+         return module_range_scope
+     else:
+         raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
+
+
+ class BaseScope(ABC):
+     Module_Type_Module = "Module"
+     Module_Type_API = "api"
+
+     def __init__(self, scope, api_list):
+         scope, api_list = self.rectify_args(scope, api_list)
+         self.scope = scope
+         self.api_list = api_list
+
+     @staticmethod
+     def rectify_args(scope, api_list):
+         if not isinstance(api_list, list):
+             raise ScopeException(ScopeException.InvalidApiStr,
+                                  f"api_list must be a list, got {type(api_list)}.")
+         for api in api_list:
+             if not isinstance(api, str):
+                 raise ScopeException(ScopeException.InvalidApiStr,
+                                      f"elements of api_list must be strings, got {type(api)}.")
+         if isinstance(scope, str):
+             scope = [scope]
+             return scope, api_list
+         if not isinstance(scope, list):
+             raise ScopeException(ScopeException.InvalidScope,
+                                  f"scope must be a string or a list, got {type(scope)}.")
+         for s in scope:
+             if not isinstance(s, str):
+                 raise ScopeException(ScopeException.InvalidScope,
+                                      f"elements of the scope list must be strings, got {type(s)}.")
+         return scope, api_list
+
+     @abstractmethod
+     def check(self, name):
+         pass
+
+     def check_api_list(self, api_name):
+         if not self.api_list:
+             return True
+         for api_str in self.api_list:
+             if api_str in api_name:
+                 return True
+         return False
+
+
+ class ListScope(BaseScope):
+     @staticmethod
+     def rectify_args(scope, api_list):
+         if scope and api_list:
+             raise ScopeException(ScopeException.ArgConflict,
+                                  f"scope and api_list cannot be configured at the same time, "
+                                  f"got scope={scope}, api_list={api_list}.")
+         return super(ListScope, ListScope).rectify_args(scope, api_list)
+
+     def check(self, module_name):
+         if not self.scope or module_name in self.scope:
+             return self.check_api_list(module_name)
+         return False
+
+
+ class RangeScope(BaseScope, ABC):
+
+     def __init__(self, *args):
+         super().__init__(*args)
+         self.in_scope = False
+         self.is_valid = self.check_scope_is_valid()
+
+     @staticmethod
+     def rectify_args(scope, api_list):
+         scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
+         if isinstance(scope, list):
+             if len(scope) == 1:
+                 scope.append(scope[0])
+             elif len(scope) > 2:
+                 raise ScopeException(ScopeException.InvalidScope,
+                                      f"scope specifies the endpoints of a range and must be a list "
+                                      f"of length 1 or 2, got length {len(scope)}.")
+
+         return scope, api_list
+
+     @abstractmethod
+     def check_scope_is_valid(self):
+         pass
+
+     def begin_module(self, module_name):
+         pass
+
+     def end_module(self, module_name):
+         pass
+
+
+ class APIRangeScope(RangeScope):
+     def check_scope_is_valid(self):
+         if not self.scope:
+             return True
+         scope_start_type = self.scope[0].split(Const.SEP)[0]
+         if scope_start_type == BaseScope.Module_Type_Module:
+             return False
+         scope_stop_type = self.scope[1].split(Const.SEP)[0]
+         if scope_stop_type == BaseScope.Module_Type_Module:
+             return False
+         return True
+
+     def check(self, api_name):
+         if self.scope and api_name == self.scope[0]:
+             self.in_scope = True
+
+         if not self.scope or self.in_scope:
+             result = self.check_api_list(api_name)
+         else:
+             result = False
+
+         if self.scope and api_name == self.scope[1]:
+             self.in_scope = False
+         return result
+
+
+ class ModuleRangeScope(RangeScope):
+     """
+     Unlike an API, a module has inner sub-structures that also need to be dumped,
+     so pre_hook and full_backward_hook are used to mark exactly where the module
+     starts and ends; begin_module and end_module are called from those hooks to
+     control the dump range.
+     """
+     def check_scope_is_valid(self):
+         if not self.scope:
+             return True
+         scope_start_type = self.scope[0].split(Const.SEP)[0]
+         scope_stop_type = self.scope[1].split(Const.SEP)[0]
+         if scope_start_type == BaseScope.Module_Type_Module and \
+                 scope_stop_type == BaseScope.Module_Type_Module:
+             return True
+         return False
+
+     def begin_module(self, module_name):
+         if not self.scope:
+             return
+         if module_name == self.scope[0]:
+             self.in_scope = True
+
+     def end_module(self, module_name):
+         if not self.scope:
+             return
+         if module_name == self.scope[1]:
+             self.in_scope = False
+
+     def check(self, module_name):
+         if not self.scope or self.in_scope:
+             return self.check_api_list(module_name)
+         return False
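By its line count (+178), the hunk above corresponds to msprobe/core/data_dump/scope.py. A minimal sketch (not part of the package) of the two ways build_scope is typically driven; the API names are illustrative only:

    from msprobe.core.data_dump.scope import build_scope, ListScope

    # Range scope: names between the two endpoints (inclusive) pass the check.
    rng = build_scope(None, scope=["Tensor.add.0.forward", "Tensor.mul.3.forward"])
    print(rng.check("Tensor.add.0.forward"))          # True, the range opens here

    # List scope: only names containing one of the api_list substrings pass.
    lst = build_scope(ListScope, api_list=["conv2d", "matmul"])
    print(lst.check("Functional.conv2d.1.forward"))   # True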
@@ -0,0 +1 @@
+ from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
File without changes
@@ -0,0 +1,51 @@
+ import os
+
+
+ class DebuggerConfig:
+     convert_map = {
+         "L0": "cell",
+         "L1": "api",
+         "L2": "kernel"
+     }
+
+     def __init__(self, common_config, task_config):
+         self.dump_path = common_config.dump_path
+         self.task = common_config.task
+         self.rank = [] if not common_config.rank else common_config.rank
+         self.step = [] if not common_config.step else common_config.step
+         if not common_config.level:
+             common_config.level = "L1"
+         self.level = DebuggerConfig.convert_map[common_config.level]
+         self.list = [] if not task_config.list else task_config.list
+         self.data_mode = [] if not task_config.data_mode else task_config.data_mode
+         self.file_format = task_config.file_format
+         self.check_mode = task_config.check_mode
+
+         self.check()
+
+     def check(self):
+         if not self.dump_path:
+             raise Exception("Dump path is empty.")
+         if not os.path.isabs(self.dump_path):
+             raise Exception("Dump path must be an absolute path.")
+         if not self.task:
+             self.task = "statistics"
+         if not self.level:
+             raise Exception("level must be L0, L1 or L2.")
+         if not self.file_format:
+             self.file_format = "npy"
+         if not self.check_mode:
+             self.check_mode = "all"
+         self._check_rank()
+         self._check_step()
+         return True
+
+     def _check_rank(self):
+         for rank_id in self.rank:
+             if not isinstance(rank_id, int) or rank_id < 0:
+                 raise ValueError(f"rank id {rank_id} must be a non-negative integer.")
+
+     def _check_step(self):
+         for s in self.step:
+             if not isinstance(s, int):
+                 raise ValueError(f"step element {s} must be an int.")