mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +39 -3
  6. msprobe/config.json +1 -3
  7. msprobe/core/advisor/advisor.py +8 -3
  8. msprobe/core/common/const.py +113 -13
  9. msprobe/core/common/exceptions.py +25 -3
  10. msprobe/core/common/file_utils.py +150 -26
  11. msprobe/core/common/inplace_op_checker.py +15 -0
  12. msprobe/core/common/log.py +27 -9
  13. msprobe/core/common/utils.py +182 -69
  14. msprobe/core/common_config.py +44 -15
  15. msprobe/core/compare/acc_compare.py +207 -142
  16. msprobe/core/compare/check.py +2 -5
  17. msprobe/core/compare/compare_cli.py +21 -4
  18. msprobe/core/compare/highlight.py +124 -55
  19. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  20. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  21. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  22. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  23. msprobe/core/compare/npy_compare.py +52 -23
  24. msprobe/core/compare/utils.py +272 -247
  25. msprobe/core/data_dump/data_collector.py +13 -11
  26. msprobe/core/data_dump/data_processor/base.py +46 -16
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
  29. msprobe/core/data_dump/scope.py +113 -34
  30. msprobe/core/grad_probe/constant.py +27 -13
  31. msprobe/core/grad_probe/grad_compare.py +18 -1
  32. msprobe/core/grad_probe/utils.py +30 -2
  33. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  34. msprobe/core/overflow_check/api_info.py +55 -0
  35. msprobe/core/overflow_check/checker.py +138 -0
  36. msprobe/core/overflow_check/filter.py +157 -0
  37. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  38. msprobe/core/overflow_check/level.py +22 -0
  39. msprobe/core/overflow_check/utils.py +28 -0
  40. msprobe/docs/01.installation.md +10 -0
  41. msprobe/docs/02.config_introduction.md +49 -22
  42. msprobe/docs/03.config_examples.md +2 -9
  43. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  44. msprobe/docs/05.data_dump_PyTorch.md +3 -1
  45. msprobe/docs/06.data_dump_MindSpore.md +157 -90
  46. msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
  47. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  48. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  49. msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
  50. msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
  51. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  52. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  53. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  54. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  55. msprobe/docs/17.grad_probe.md +5 -6
  56. msprobe/docs/19.monitor.md +468 -0
  57. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  58. msprobe/docs/21.visualization_PyTorch.md +386 -0
  59. msprobe/docs/22.visualization_MindSpore.md +384 -0
  60. msprobe/docs/23.tool_function_introduction.md +28 -0
  61. msprobe/docs/FAQ.md +3 -0
  62. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  63. msprobe/docs/img/compare_result.png +0 -0
  64. msprobe/docs/img/monitor/cpu_info.png +0 -0
  65. msprobe/mindspore/__init__.py +15 -0
  66. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
  67. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  68. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  69. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  70. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  71. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  72. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  73. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  74. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  75. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  76. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  77. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  78. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  79. msprobe/mindspore/cell_processor.py +33 -12
  80. msprobe/mindspore/common/const.py +33 -13
  81. msprobe/mindspore/common/log.py +5 -9
  82. msprobe/mindspore/common/utils.py +43 -4
  83. msprobe/mindspore/compare/distributed_compare.py +22 -22
  84. msprobe/mindspore/compare/ms_compare.py +271 -248
  85. msprobe/mindspore/compare/ms_graph_compare.py +81 -47
  86. msprobe/mindspore/debugger/debugger_config.py +4 -1
  87. msprobe/mindspore/debugger/precision_debugger.py +7 -1
  88. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  89. msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
  90. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
  91. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
  92. msprobe/mindspore/dump/jit_dump.py +17 -5
  93. msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
  94. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  95. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  96. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  97. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
  98. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  99. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  100. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  101. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  102. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  103. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  104. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  105. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  106. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  107. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
  108. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  109. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  110. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  111. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  112. msprobe/mindspore/grad_probe/global_context.py +28 -8
  113. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  114. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  115. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  116. msprobe/mindspore/grad_probe/hook.py +24 -10
  117. msprobe/mindspore/grad_probe/utils.py +18 -5
  118. msprobe/mindspore/ms_config.py +22 -15
  119. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
  120. msprobe/mindspore/runtime.py +15 -0
  121. msprobe/mindspore/service.py +36 -30
  122. msprobe/mindspore/task_handler_factory.py +15 -0
  123. msprobe/msprobe.py +24 -7
  124. msprobe/pytorch/__init__.py +3 -2
  125. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  126. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
  127. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  128. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  129. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  130. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
  131. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
  132. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
  133. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
  135. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  136. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  137. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  138. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  139. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  140. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  141. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  142. msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
  143. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  144. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  145. msprobe/pytorch/common/parse_json.py +6 -6
  146. msprobe/pytorch/common/utils.py +56 -5
  147. msprobe/pytorch/compare/distributed_compare.py +8 -9
  148. msprobe/pytorch/compare/pt_compare.py +8 -6
  149. msprobe/pytorch/debugger/debugger_config.py +19 -15
  150. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  151. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  152. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  153. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  154. msprobe/pytorch/free_benchmark/common/params.py +8 -1
  155. msprobe/pytorch/free_benchmark/common/utils.py +26 -4
  156. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
  157. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  158. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  159. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  160. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  161. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  162. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
  163. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  164. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  165. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  166. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  167. msprobe/pytorch/hook_module/wrap_functional.py +14 -12
  168. msprobe/pytorch/module_processer.py +2 -5
  169. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  170. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  171. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  172. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  173. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  174. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  175. msprobe/pytorch/monitor/features.py +108 -0
  176. msprobe/pytorch/monitor/module_hook.py +870 -0
  177. msprobe/pytorch/monitor/module_metric.py +193 -0
  178. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  179. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  180. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  181. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  182. msprobe/pytorch/monitor/utils.py +250 -0
  183. msprobe/pytorch/monitor/visualizer.py +59 -0
  184. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  185. msprobe/pytorch/online_dispatch/compare.py +29 -38
  186. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  187. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  188. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  189. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  190. msprobe/pytorch/online_dispatch/utils.py +49 -21
  191. msprobe/pytorch/parse_tool/lib/compare.py +12 -18
  192. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  193. msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
  194. msprobe/pytorch/parse_tool/lib/utils.py +16 -35
  195. msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
  196. msprobe/pytorch/pt_config.py +31 -8
  197. msprobe/pytorch/service.py +15 -5
  198. msprobe/visualization/__init__.py +14 -0
  199. msprobe/visualization/builder/__init__.py +14 -0
  200. msprobe/visualization/builder/graph_builder.py +165 -0
  201. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  202. msprobe/visualization/compare/__init__.py +14 -0
  203. msprobe/visualization/compare/graph_comparator.py +130 -0
  204. msprobe/visualization/compare/mode_adapter.py +211 -0
  205. msprobe/visualization/graph/__init__.py +14 -0
  206. msprobe/visualization/graph/base_node.py +124 -0
  207. msprobe/visualization/graph/graph.py +200 -0
  208. msprobe/visualization/graph/node_colors.py +95 -0
  209. msprobe/visualization/graph/node_op.py +39 -0
  210. msprobe/visualization/graph_service.py +214 -0
  211. msprobe/visualization/utils.py +232 -0
  212. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  213. msprobe/docs/04.acl_config_examples.md +0 -78
  214. msprobe/mindspore/compare/layer_mapping.py +0 -146
  215. msprobe/mindspore/compare/modify_mapping.py +0 -107
  216. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  217. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  218. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  219. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  220. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/pytorch/online_dispatch/dispatch.py

@@ -1,10 +1,24 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
 import os
 import time
-import json
 from multiprocessing import Pool
 
 import torch
-
 from torch.utils._python_dispatch import TorchDispatchMode
 
 try:
@@ -14,15 +28,15 @@ except ImportError:
 else:
     is_npu = True
 
-from msprobe.core.common.file_utils import check_path_before_create, check_file_or_directory_path, load_yaml
+from msprobe.core.common.file_utils import check_file_or_directory_path, load_yaml, FileOpen, create_directory
 from msprobe.core.common.const import Const, CompareConst
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \
-    DispatchRunParam, DisPatchDataInfo
-from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, COMPARE_LOGO
+from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, \
+    TimeStatistics, DispatchRunParam, DisPatchDataInfo
+from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, \
+    COMPARE_LOGO
 from msprobe.pytorch.online_dispatch.compare import Comparator
-from msprobe.core.common.file_utils import FileOpen, create_directory
-
+from msprobe.core.common.utils import check_str_param, safe_get_value
 
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
@@ -51,14 +65,13 @@ class PtdbgDispatch(TorchDispatchMode):
         self.all_summary = []
         self.call_stack_list = []
         self.process_num = process_num
-        self.filter_dump_api()
+        self.tag = tag
         self.check_param()
+        self.filter_dump_api()
         dir_name = self.get_dir_name(tag)
         self.root_path = os.path.join(os.path.realpath(dump_path), dir_name)
         self.root_cpu_path = os.path.join(self.root_path, f'cpu')
         self.root_npu_path = os.path.join(self.root_path, f'npu')
-        check_path_before_create(self.root_cpu_path)
-        check_path_before_create(self.root_npu_path)
         create_directory(self.root_cpu_path)
         create_directory(self.root_npu_path)
 
@@ -67,7 +80,7 @@ class PtdbgDispatch(TorchDispatchMode):
         self.comparator = Comparator(self.result_csv_path, self.detail_csv_path, False)
 
         self.aten_ops_blacklist = []
-        self.npu_adjust_autogard = []
+        self.npu_adjust_autograd = []
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
         self.get_ops(yaml_path)
 
@@ -76,8 +89,8 @@ class PtdbgDispatch(TorchDispatchMode):
             self.pool = Pool(process_num)
         if debug:
             logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
-                f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
-                f'process[{process_num}]')
+                        f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
+                        f'process[{process_num}]')
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         super().__exit__(exc_type, exc_val, exc_tb)
@@ -119,7 +132,7 @@ class PtdbgDispatch(TorchDispatchMode):
             output_num = output_num + 1
         total_num = total_num + 1
         logger.info(f'Dispatch exit: Device[{self.device_id}], Pid[{os.getpid()} Input[{input_num}] '
-            f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
+                    f'Output[{output_num}] Total[{total_num}] API_Total[{self.api_index}]]')
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         if not is_npu:
@@ -134,7 +147,7 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error(f"Please check the func name {func.__name__}!")
             return func(*args, **kwargs)
 
-        self.enable_autogard(aten_api)
+        self.enable_autograd(aten_api)
         if aten_api in self.aten_ops_blacklist:
             npu_out = func(*args, **kwargs)
             return npu_out
@@ -151,21 +164,22 @@ class PtdbgDispatch(TorchDispatchMode):
 
         if self.debug_flag:
             logger.info(f'Dispatch Info: Rank[{self.device_id}], Pid[{os.getpid()}], Func[{func.__name__}], '
-                f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
-                f'Count[{self.api_index}], Sys[{get_sys_info()}]')
+                        f'Name[{run_param.aten_api}_{run_param.single_api_index}], '
+                        f'Count[{self.api_index}], Sys[{get_sys_info()}]')
 
         cpu_args = []
         cpu_kwargs = []
         data_to_cpu(args, 0, cpu_args)
         data_to_cpu(kwargs, 0, cpu_kwargs)
-        cpu_args = cpu_args[0]
-        cpu_kwargs = cpu_kwargs[0]
+
+        cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
+        cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
 
         with TimeStatistics("NPU RUN", run_param):
             npu_out = func(*args, **kwargs)
         npu_out_cpu = []
         data_to_cpu(npu_out, 0, npu_out_cpu)
-        npu_out_cpu = npu_out_cpu[0]
+        npu_out_cpu = safe_get_value(npu_out_cpu, 0, "npu_out_cpu")
 
         with TimeStatistics("CPU RUN", run_param):
             cpu_out = func(*cpu_args, **cpu_kwargs)
@@ -216,7 +230,7 @@ class PtdbgDispatch(TorchDispatchMode):
     def get_ops(self, file_path):
         yaml_file = load_yaml(file_path)
         self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist')
-        self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard')
+        self.npu_adjust_autograd = yaml_file.get('npu_adjust_autograd')
 
     def filter_dump_api(self):
         if self.dump_mode != Const.LIST or not self.dump_api_list:
@@ -260,6 +274,17 @@ class PtdbgDispatch(TorchDispatchMode):
         if not isinstance(self.dump_api_list, list):
             logger.error('The type of parameter "api_list" can only be list.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
+        if not all(isinstance(item, str) for item in self.dump_api_list):
+            logger.error('The type of parameter in "api_list" can only be str.')
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+        if len(self.dump_api_list) > Const.STEP_RANK_MAXIMUM_VALUE:
+            logger.error('The length of parameter "api_list" should not be greater '
+                         f'than {Const.STEP_RANK_MAXIMUM_VALUE}.')
+            raise DispatchException(DispatchException.INVALID_PARAMETER)
+        for item in self.dump_api_list:
+            check_str_param(item)
+        if self.tag is not None:
+            check_str_param(self.tag)
         if not isinstance(self.debug_flag, bool):
             logger.error('The type of parameter "debug" can only be bool.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
@@ -267,6 +292,6 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.')
             raise DispatchException(DispatchException.INVALID_PARAMETER)
 
-    def enable_autogard(self, aten_api):
-        if aten_api in self.npu_adjust_autogard:
-            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
+    def enable_autograd(self, aten_api):
+        if aten_api in self.npu_adjust_autograd:
+            torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.AutogradFunctionality, False)
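The stricter check_param rules and the new safe_get_value calls replace bare list indexing with explicit validation. A minimal sketch of what a helper like msprobe.core.common.utils.safe_get_value plausibly does follows; the real implementation and its exception type are not shown in this diff, so treat the body as an assumption rather than the library's code.

def safe_get_value(container, index, container_name):
    # Guarded indexed access: an empty or malformed container produces a
    # clear error naming the container instead of a bare IndexError deep
    # inside the dispatch workflow.
    try:
        return container[index]
    except (IndexError, TypeError) as err:
        raise ValueError(f"Failed to get value from {container_name} at index {index}") from err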
msprobe/pytorch/online_dispatch/dump_compare.py

@@ -1,11 +1,26 @@
-import os
-import json
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import copy
+import json
+import os
 from datetime import datetime, timezone
 
 import torch
+from msprobe.core.common.file_utils import FileOpen, save_npy, save_json
 from msprobe.pytorch.common.log import logger
-from msprobe.core.common.file_utils import FileOpen, save_npy
 
 
 class DispatchRunParam:
@@ -55,7 +70,7 @@ class TimeStatistics:
         if self.debug:
             self.time = datetime.now(tz=timezone.utc)
             logger.info(f'Time[{self.tag}]-ENTER: Dev[{self.device}], Pid[{os.getpid()}], Fun[{self.fun}], ' \
-                f'Id[{self.index}]')
+                        f'Id[{self.index}]')
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self.debug:
@@ -92,10 +107,8 @@ def dump_data(data, prefix, dump_path):
 def save_temp_summary(api_index, single_api_summary, path, lock):
     summary_path = os.path.join(path, f'summary.json')
     lock.acquire()
-    with FileOpen(summary_path, "a") as f:
-        json.dump([api_index, single_api_summary], f)
-        f.write('\n')
-    lock.release()
+    data = [api_index, single_api_summary]
+    save_json(summary_path, data, mode='a')
 
 
 def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo):
@@ -152,4 +165,3 @@ def dispatch_multiprocess(run_param, dispatch_data_info):
 
 def error_call(err):
     logger.error(f'multiprocess {err}')
-
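save_temp_summary now hands the append-mode write to save_json from msprobe.core.common.file_utils instead of composing FileOpen, json.dump and a trailing newline by hand. A rough stand-in for that call, assuming it appends one JSON document per line; the real helper also performs the path and permission checks that FileOpen used to provide.

import json

def save_json_sketch(path, data, mode='w'):
    # Append mode writes one JSON document per line, matching the summary
    # records that save_temp_summary accumulates under the lock.
    with open(path, mode) as f:
        json.dump(data, f)
        f.write('\n')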
msprobe/pytorch/online_dispatch/single_compare.py

@@ -1,9 +1,27 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import logging
+from collections import namedtuple
 from functools import wraps
+
 import torch
-from prettytable import PrettyTable
-from collections import namedtuple
 from msprobe.pytorch.common.log import logger
+from msprobe.pytorch.online_dispatch.utils import check_idx_valid
+from prettytable import PrettyTable
+
 
 def func_log_wrapper():
     def _out_wrapper(func):
@@ -13,9 +31,9 @@ def func_log_wrapper():
             x = func(*kargs, **kwargs)
             logger.info(f"end to run: {func.__name__}")
             return x
-
+
         return _in_wrapper
-
+
     return _out_wrapper
 
 
@@ -31,7 +49,7 @@ class SingleBenchmarkCompareStandard:
               torch.bfloat16: 2 ** -7,
               torch.float32: 2 ** -14,
               torch.float64: 2 ** -14}
-
+
     def get_error_thd(self, dtype):
        if dtype in self.error_thd.keys():
            if dtype == torch.float64:
@@ -42,12 +60,12 @@ class SingleBenchmarkCompareStandard:
                "in fp16, bf16, fp32. "
            )
        return None
-
+
     def get_eb_thd(self, dtype):
         if dtype in self.eb_thd.keys():
             return self.eb_thd.get(dtype)
         return None
-
+
 
 
 class SingleBenchmarkAccuracyResult:
@@ -82,7 +100,7 @@ class SingleBenchmarkAccuracyCompare:
     @func_log_wrapper()
     def check_output_size(cls, npu_out, bench_out):
         acc_result = None
-        if npu_out.numel() == 0 and bench_out.nuimel() == 0:
+        if npu_out.numel() == 0 and bench_out.numel() == 0:
             info = (
                 "The npu_output is [], and it is same as benchmark_output, "
                 "the result of data_compare is Pass"
@@ -99,14 +117,14 @@ class SingleBenchmarkAccuracyCompare:
             logging.error(error_info)
             acc_result = SingleBenchmarkAccuracyResult(result=False)
         return acc_result
-
+
     @classmethod
     @func_log_wrapper()
     def check_output_invalid_value(cls, output):
         has_nan = torch.isnan(output).any()
         has_inf = torch.isinf(output).any()
         return has_nan or has_inf
-
+
     @classmethod
     @func_log_wrapper()
     def precision_compare_for_case(cls, npu_out, bench_out, benchmark_standard: SingleBenchmarkCompareStandard):
@@ -119,19 +137,19 @@ class SingleBenchmarkAccuracyCompare:
         if acc_result:
             failed_info = "比对数据的shape不一致"
             return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
-
+
         if cls.check_output_invalid_value(bench_out):
             logging.info("The benchmark result contains nan/inf value. ")
             failed_info = "标杆结果存在nan值或inf值, 依照单标杆标准该用例通过"
             acc_result = SingleBenchmarkAccuracyResult(result=True)
             return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
-
+
         if cls.check_output_invalid_value(npu_out):
             logging.info("The NPU result contains nan/inf value. ")
             failed_info = "NPU结果存在nan值或inf值, 依照单标杆标准该用例不通过"
             acc_result = SingleBenchmarkAccuracyResult(result=False)
             return CompareResultInfo(acc_result, error_thd, eb_thd, failed_info)
-
+
         data_type = npu_out.dtype
         if data_type not in [torch.float16, torch.float32, torch.float64, torch.bfloat16]:
             acc_result = cls.compute_binary_diff(npu_out, bench_out)
@@ -159,7 +177,6 @@ class SingleBenchmarkAccuracyCompare:
         acc_result.get_result(eb_thd, error_thd)
         return CompareResultInfo(acc_result, error_thd, eb_thd, None)
 
-
     @classmethod
     @func_log_wrapper()
     def compute_binary_diff(cls, npu_out, bench_out):
@@ -167,7 +184,7 @@ class SingleBenchmarkAccuracyCompare:
         if result:
             logger.info("二进制精度比对通过, 无需单标杆比对法验证")
         return SingleBenchmarkAccuracyResult(result=result, max_abs_diff=0, max_rel_diff=0, error_balance=0)
-
+
     @classmethod
     @func_log_wrapper()
     def compute_error_balance(cls, npu_out, bench_out, benchmark_standard: SingleBenchmarkCompareStandard):
@@ -176,11 +193,11 @@ class SingleBenchmarkAccuracyCompare:
         abs_mask_idx = torch.where(torch.abs(bench_out) < benchmark_standard.small_value, ones, zeros)
         abs_mask_idx = abs_mask_idx.type(torch.bool)
         diff_value = torch.subtract(npu_out, bench_out)
-        diff_value_rel = diff_value / (torch.abs(bench_out) + torch.finfo(torch.float).eps )
+        diff_value_rel = diff_value / (torch.abs(bench_out) + torch.finfo(torch.float).eps)
         rel_and_abs = torch.where(abs_mask_idx, diff_value, diff_value_rel)
         eb_float = float(torch.mean(rel_and_abs))
         return eb_float
-
+
     @classmethod
     @func_log_wrapper()
     def compute_abs_diff(cls, npu_out, bench_out, error_thd, benchmark_standard: SingleBenchmarkCompareStandard):
@@ -200,15 +217,16 @@ class SingleBenchmarkAccuracyCompare:
             err_for_max = torch.where(abs_err_idx == 1, diff_abs, zeros)
             logging.debug("err_for_max for abs %s", err_for_max)
             max_abs_idx = torch.argmax(err_for_max)
-            max_abs_diff = diff_abs[max_abs_idx]
+            if check_idx_valid(diff_abs, max_abs_idx):
+                max_abs_diff = diff_abs[max_abs_idx]
         elif torch.sum(abs_mask_idx) > 0:
             err_for_max = torch.where(abs_mask_idx == 1, diff_abs, zeros)
             logging.debug("error_for_max for abs %s", err_for_max)
             max_abs_idx = torch.argmax(err_for_max)
-            if err_for_max.max() != 0:
+            if err_for_max.max() != 0 and check_idx_valid(diff_abs, max_abs_idx):
                 max_abs_diff = diff_abs[max_abs_idx]
         return (float(max_abs_diff), int(max_abs_idx) if torch.is_tensor(max_abs_idx) else max_abs_idx)
-
+
     @classmethod
     @func_log_wrapper()
     def compute_rel_diff(cls, npu_out, bench_out, error_thd, benchmark_standard: SingleBenchmarkCompareStandard):
@@ -221,7 +239,7 @@ class SingleBenchmarkAccuracyCompare:
         diff_abs = torch.abs(diff_value)
 
         rel_mask_idx = torch.where(torch.abs(bench_out) >= benchmark_standard.small_value, ones, zeros)
-        rel_err = diff_abs / (torch.abs(bench_out) + torch.finfo(torch.float).eps )
+        rel_err = diff_abs / (torch.abs(bench_out) + torch.finfo(torch.float).eps)
         diff_rel = rel_err
         rel_err_idx = torch.where(rel_err > error_thd, ones, zeros)
         rel_err_idx = rel_err_idx * rel_mask_idx
@@ -230,19 +248,20 @@ class SingleBenchmarkAccuracyCompare:
             err_for_max = torch.where(rel_err_idx == 1, diff_rel, zeros)
             logging.debug("error_for_max for rel %s", err_for_max)
             max_rel_idx = torch.argmax(err_for_max)
-            max_rel_diff = diff_rel[max_rel_idx]
+            if check_idx_valid(diff_rel, max_rel_idx):
+                max_rel_diff = diff_rel[max_rel_idx]
         elif torch.sum(rel_mask_idx > 0):
             err_for_max = torch.where(rel_mask_idx == 1, diff_rel, zeros)
             logging.debug("err_for_max for rel %s", err_for_max)
             max_rel_idx = torch.argmax(err_for_max)
-            if torch.sum(err_for_max) != 0:
+            if torch.sum(err_for_max) != 0 and check_idx_valid(diff_rel, max_rel_idx):
                 max_rel_diff = diff_rel[max_rel_idx]
         return (float(max_rel_diff), int(max_rel_idx) if torch.is_tensor(max_rel_idx) else max_rel_idx)
 
 
 class SingleBenchSummary:
     def __init__(self, precision_result: SingleBenchmarkAccuracyResult, npu_dtype=None,
-            bench_dtype=None, shape=None, error_thd=None, eb_thd=None, failed_info=None):
+                 bench_dtype=None, shape=None, error_thd=None, eb_thd=None, failed_info=None):
         self.npu_dtype = npu_dtype
         self.bench_dtype = bench_dtype
         self.shape = shape
@@ -261,12 +280,13 @@ class SingleBenchSummary:
             return "PASS"
         else:
             return "FAILED"
-
+
     def get_result_msg(self):
         result_str = ""
         if self.failed_info:
-            return self.failed_info
-
+            result_str = self.failed_info
+            return result_str
+
         if self.result:
             result_str += "误差均衡性EB: %s <= 阈值%s\n" % (self.error_balance, self.eb_thd)
             result_str += "最大绝对误差: %s <= 阈值%s\n" % (self.max_abs_diff, self.error_thd)
@@ -290,7 +310,7 @@ class SingleBenchSummary:
                 self.max_rel_diff,
             )
         return result_str
-
+
     def print_detail_table(self):
         table = PrettyTable()
         table.title = "Single Benchmark Metrics Info"
@@ -307,7 +327,7 @@ class SingleBenchSummary:
         return [self.bench_dtype, self.npu_dtype, self.shape, self.error_balance,
                 self.max_abs_diff, self.max_abs_idx, self.max_rel_diff, self.max_rel_idx,
                 self.eb_thd, self.error_thd, self.result, self.failed_info]
-
+
 
 
 def single_benchmark_compare(npu_out: torch.Tensor, bench_out: torch.Tensor, high_precision: bool = True):
@@ -322,8 +342,9 @@ def single_benchmark_compare(npu_out: torch.Tensor, bench_out: torch.Tensor, hig
         failed_info
     ) = (compare_results.accuracy_result, compare_results.error_threshold,
         compare_results.eb_threshold, compare_results.failed_information)
-
-    summary = SingleBenchSummary(precision_result, str(npu_out.dtype), str(bench_out.dtype), tuple(npu_out.shape), error_thd, eb_thd, failed_info)
+
+    summary = SingleBenchSummary(precision_result, str(npu_out.dtype), str(bench_out.dtype), tuple(npu_out.shape),
+                                 error_thd, eb_thd, failed_info)
     result = summary.result
     details = summary.to_column_value()
     return result, details
@@ -349,7 +370,7 @@ def calc_status_details_dict(npu_out, bench_out, summary):
             summary.failed_info = "bench and npu_output dict keys are different."
             return False, summary.to_column_value()
         else:
-            status, details = single_benchmark_compare_wrap(list(bench_out.values(), list(npu_out.values())))
+            status, details = single_benchmark_compare_wrap(list(bench_out.values()), list(npu_out.values()))
             return status, details
 
 
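A hedged usage sketch of this module's entry point, single_benchmark_compare, as it appears in the hunks above. The tensors below are made-up sample data, the import path follows the file location msprobe/pytorch/online_dispatch/single_compare.py, and the exact column layout of details is defined by SingleBenchSummary.to_column_value.

import torch
from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare

npu_out = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
bench_out = torch.tensor([1.0, 2.0, 3.0001], dtype=torch.float32)

# result is the pass/fail verdict; details holds dtypes, shape, error balance,
# max absolute/relative error and their indices, thresholds and failure info.
result, details = single_benchmark_compare(npu_out, bench_out, high_precision=True)
print(result, details)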
msprobe/pytorch/online_dispatch/torch_ops_config.yaml

@@ -49,7 +49,7 @@ aten_ops_blacklist:
   - zeros
   - zeros_like
 
-npu_adjust_autogard:
+npu_adjust_autograd:
   - adaptive_avg_pool2d
   - batch_norm
   - log_softmax
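Both lists in this file are read back by PtdbgDispatch.get_ops through load_yaml (see the dispatch.py hunk above), so the key rename has to match on both sides. A minimal sketch of the same read with plain PyYAML, assuming the key layout shown here; msprobe's own load_yaml wraps this with file checks.

import yaml

with open("torch_ops_config.yaml") as f:
    cfg = yaml.safe_load(f)

aten_ops_blacklist = cfg.get('aten_ops_blacklist', [])
npu_adjust_autograd = cfg.get('npu_adjust_autograd', [])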
msprobe/pytorch/online_dispatch/utils.py

@@ -1,7 +1,23 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import inspect
+
+import numpy as np
 import psutil
 import torch
-import numpy as np
 
 try:
     import torch_npu
@@ -11,6 +27,7 @@ else:
     pta_cpu_device = torch.device("cpu")
 
 from msprobe.core.common.const import CompareConst
+from msprobe.pytorch.common.log import logger
 
 cpu_device = torch._C.device("cpu")
 COLOR_RED = '\033[31m'
@@ -31,24 +48,26 @@ COMPARE_LOGO = '''
              |_|
 '''
 
-CSV_COLUMN_NAME = [CompareConst.NPU_NAME,
-                   CompareConst.BENCH_NAME,
-                   CompareConst.NPU_DTYPE,
-                   CompareConst.BENCH_DTYPE,
-                   CompareConst.NPU_SHAPE,
-                   CompareConst.BENCH_SHAPE,
-                   CompareConst.NPU_MAX,
-                   CompareConst.NPU_MIN,
-                   CompareConst.NPU_MEAN,
-                   CompareConst.BENCH_MAX,
-                   CompareConst.BENCH_MIN,
-                   CompareConst.BENCH_MEAN,
-                   CompareConst.COSINE,
-                   CompareConst.MAX_ABS_ERR,
-                   CompareConst.MAX_RELATIVE_ERR,
-                   CompareConst.ACCURACY,
-                   CompareConst.STACK,
-                   CompareConst.ERROR_MESSAGE]
+CSV_COLUMN_NAME = [
+    CompareConst.NPU_NAME,
+    CompareConst.BENCH_NAME,
+    CompareConst.NPU_DTYPE,
+    CompareConst.BENCH_DTYPE,
+    CompareConst.NPU_SHAPE,
+    CompareConst.BENCH_SHAPE,
+    CompareConst.NPU_MAX,
+    CompareConst.NPU_MIN,
+    CompareConst.NPU_MEAN,
+    CompareConst.BENCH_MAX,
+    CompareConst.BENCH_MIN,
+    CompareConst.BENCH_MEAN,
+    CompareConst.COSINE,
+    CompareConst.MAX_ABS_ERR,
+    CompareConst.MAX_RELATIVE_ERR,
+    CompareConst.ACCURACY,
+    CompareConst.STACK,
+    CompareConst.ERROR_MESSAGE
+]
 
 FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
 BOOL_TYPE = [bool, np.uint8]
@@ -58,8 +77,11 @@ INT_TYPE = [np.int32, np.int64]
 def get_callstack():
     callstack = []
     for (_, path, line, func, code, _) in inspect.stack()[2:]:
-        stack_line = [path, str(line), func, code[0].strip() if code else code]
-        callstack.append(stack_line)
+        try:
+            stack_line = [path, str(line), func, code[0].strip() if code else code]
+            callstack.append(stack_line)
+        except IndexError:
+            logger.error("Failed to get callstack for code:{} index out of range".format(code))
     return callstack
 
 
@@ -125,3 +147,9 @@ class DispatchException(Exception):
 
     def __str__(self):
         return self.err_msg
+
+
+def check_idx_valid(data, idx):
+    if data is not None and data.numel() > 0 and 0 <= idx < data.numel():
+        return True
+    return False
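A short illustration of the check_idx_valid guard added above, which compute_abs_diff and compute_rel_diff in single_compare.py now call before indexing; the tensor values are sample data only.

import torch
from msprobe.pytorch.online_dispatch.utils import check_idx_valid

diff_abs = torch.tensor([0.1, 0.5, 0.2])
max_abs_idx = torch.argmax(diff_abs)          # tensor(1)
if check_idx_valid(diff_abs, max_abs_idx):    # True: tensor is non-empty and 0 <= 1 < 3
    max_abs_diff = diff_abs[max_abs_idx]      # safe indexed access

empty = torch.tensor([])
print(check_idx_valid(empty, 0))              # False: numel() == 0, so the index is never used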
msprobe/pytorch/parse_tool/lib/compare.py

@@ -22,7 +22,7 @@ from collections import namedtuple
 from msprobe.pytorch.parse_tool.lib.utils import Util
 from msprobe.pytorch.parse_tool.lib.config import Const
 from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException
-from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv
+from msprobe.core.common.file_utils import create_directory, load_npy, save_npy_to_txt, write_csv, os_walk_for_files
 
 
 class Compare:
@@ -240,20 +240,14 @@
 
     def convert_api_dir_to_npy(self, dump_dir, param, output_dir, msaccucmp_path):
         dump_dir = self.util.path_strip(dump_dir)
-        for root, _, files in os.walk(dump_dir, topdown=True):
-            self.util.check_path_valid(root)
-            for file in files:
-                file_path = os.path.join(root, file)
-                file_name = os.path.basename(file_path)
-                parts = file_name.split(".")
-                if len(parts) < 5:
-                    continue
-                op_name = parts[1]
-                timestamp = parts[-1]
-                output_path = os.path.join(output_dir, op_name, timestamp)
-                self.convert_dump_to_npy(file_path, param, output_path, msaccucmp_path)
-            path_depth = root.count(os.sep)
-            if path_depth <= Const.MAX_TRAVERSAL_DEPTH:
-                yield root, _, files
-            else:
-                _[:] = []
+        files = os_walk_for_files(dump_dir, Const.MAX_TRAVERSAL_DEPTH)
+        filepaths = [os.path.join(file['root'], file['file']) for file in files]
+        for path in filepaths:
+            filename = os.path.basename(path)
+            parts = filename.split(".")
+            if len(parts) < 5:
+                continue
+            op_name = parts[1]
+            timestamp = parts[-1]
+            output_path = os.path.join(output_dir, op_name, timestamp)
+            self.convert_dump_to_npy(path, param, output_path, msaccucmp_path)
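convert_api_dir_to_npy now consumes os_walk_for_files, which the new code treats as returning records with 'root' and 'file' keys. A depth-limited sketch of such a walker under that assumption; the real msprobe.core.common.file_utils.os_walk_for_files also validates the paths it visits.

import os

def os_walk_for_files_sketch(directory, max_depth):
    results = []
    base_depth = os.path.realpath(directory).count(os.sep)
    for root, dirs, files in os.walk(directory, topdown=True):
        # Stop descending once the depth limit is reached.
        if root.count(os.sep) - base_depth >= max_depth:
            dirs[:] = []
        for name in files:
            results.append({'root': root, 'file': name})
    return results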
msprobe/pytorch/parse_tool/lib/config.py

@@ -33,7 +33,7 @@ class Const:
     OFFLINE_DUMP_CONVERT_PATTERN = \
         r"^([A-Za-z0-9_-]+)\.([A-Za-z0-9_-]+)\.([0-9]+)(\.[0-9]+)?\.([0-9]{1,255})" \
         r"\.([a-z]+)\.([0-9]{1,255})(\.[x0-9]+)?\.npy$"
-    NUMPY_PATTERN = r"^[\w\-_-]\.npy$"
+    NUMPY_PATTERN = r"^[\w\-_.]+\.npy$"
     NPY_SUFFIX = ".npy"
     PKL_SUFFIX = ".pkl"
     DIRECTORY_LENGTH = 4096
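The corrected NUMPY_PATTERN accepts realistic dump file names; the old character class allowed exactly one character before the .npy suffix and did not include dots. A quick check with a representative, made-up name.

import re

old_pattern = r"^[\w\-_-]\.npy$"
new_pattern = r"^[\w\-_.]+\.npy$"
name = "Functional.conv2d.0.forward.input.0.npy"
print(bool(re.match(old_pattern, name)))  # False: only a single leading character was allowed
print(bool(re.match(new_pattern, name)))  # True: dots and multi-character names now match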
msprobe/pytorch/parse_tool/lib/parse_tool.py

@@ -132,8 +132,7 @@ class ParseTool:
                                " '-m' and '-g'.")
             raise ParseException("My directory path and golden directory path is same.")
         output_path = self.util.path_strip(args.output_path) if args.output_path else Const.BATCH_COMPARE_DIR
-        if not os.path.isdir(output_path):
-            os.makedirs(output_path, mode=0o750)
+        create_directory(output_path)
         self.compare.compare_converted_dir(my_dump_dir, golden_dump_dir, output_path)
 
     @catch_exception