mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,30 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import multiprocessing
2
17
  import os
3
- import json
4
18
  import pandas as pd
5
- from msprobe.core.common.file_utils import FileOpen
19
+ from tqdm import tqdm
20
+ from msprobe.core.common.file_utils import load_json
6
21
  from msprobe.core.common.const import CompareConst, Const
7
22
  from msprobe.core.common.exceptions import FileCheckException
8
23
  from msprobe.core.common.log import logger
9
- from msprobe.core.common.utils import add_time_with_xlsx, CompareException
24
+ from msprobe.core.common.utils import add_time_with_xlsx, CompareException, check_op_str_pattern_valid
10
25
  from msprobe.core.common.file_utils import remove_path
11
- from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op
26
+ from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op, check_dump_json_str, \
27
+ check_stack_json_str
12
28
  from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx
13
29
  from msprobe.core.compare.utils import read_op, merge_tensor, get_un_match_accuracy, get_accuracy
14
30
  from msprobe.core.compare.multiprocessing_compute import _handle_multi_process, ComparisonResult, _save_cmp_result
@@ -21,10 +37,53 @@ class Comparator:
21
37
 
22
38
  def __init__(self):
23
39
  pass
40
+
41
+ @staticmethod
42
+ def get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all, bench_ops_all, *args):
43
+ result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
44
+ bench_ops_all.get(bench_op_name).get('struct')[0],
45
+ npu_ops_all.get(ms_op_name).get('struct')[1],
46
+ bench_ops_all.get(bench_op_name).get('struct')[1],
47
+ npu_ops_all.get(ms_op_name).get('struct')[2],
48
+ bench_ops_all.get(bench_op_name).get('struct')[2],
49
+ CompareConst.PASS if npu_ops_all.get(ms_op_name).get('struct')[2]
50
+ == bench_ops_all.get(bench_op_name).get('struct')[2]
51
+ else CompareConst.DIFF]
52
+ if args[0]:
53
+ result_item.extend(args[1])
54
+ else:
55
+ result_item.append(CompareConst.NONE)
56
+ return result_item
57
+
58
+ @staticmethod
59
+ def calculate_summary_data(npu_summary_data, bench_summary_data, result_item):
60
+ err_msg = ""
61
+ start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF)
62
+ warning_flag = False
63
+ for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)):
64
+ if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)):
65
+ diff = npu_val - bench_val
66
+ if bench_val != 0:
67
+ relative = str(abs((diff / bench_val) * 100)) + '%'
68
+ else:
69
+ relative = "N/A"
70
+ result_item[start_idx + i] = diff
71
+ result_item[start_idx + i + 4] = relative
72
+ magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10)
73
+ if magnitude_diff > 0.5:
74
+ warning_flag = True
75
+ else:
76
+ result_item[start_idx + i] = CompareConst.NONE
77
+ accuracy_check = CompareConst.WARNING if warning_flag else ""
78
+ err_msg += "Need double check api accuracy." if warning_flag else ""
79
+ for i in range(start_idx, len(result_item)):
80
+ if str(result_item[i]) in ('inf', '-inf', 'nan'):
81
+ result_item[i] = f'{result_item[i]}\t'
82
+ result_item.append(accuracy_check)
83
+ result_item.append(err_msg)
24
84
 
25
85
  @classmethod
26
- def make_result_table(cls,result, md5_compare, summary_compare, stack_mode):
27
- header = []
86
+ def make_result_table(cls, result, md5_compare, summary_compare, stack_mode):
28
87
  if md5_compare:
29
88
  header = CompareConst.MD5_COMPARE_RESULT_HEADER[:]
30
89
  elif summary_compare:
@@ -47,17 +106,22 @@ class Comparator:
47
106
  else:
48
107
  for row in result:
49
108
  del row[-1]
50
- result_df = pd.DataFrame(result, columns=header)
109
+ result_df = pd.DataFrame(result, columns=header, dtype='object')
51
110
  return result_df
52
111
 
53
112
  @classmethod
54
- def gen_merge_list(self, json_data, op_name,stack_json_data, summary_compare, md5_compare):
113
+ def gen_merge_list(cls, json_data, op_name, stack_json_data, summary_compare, md5_compare):
55
114
  op_data = json_data['data'][op_name]
115
+ check_dump_json_str(op_data, op_name)
56
116
  op_parsed_list = read_op(op_data, op_name)
57
- if op_name in stack_json_data:
58
- op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_json_data[op_name]})
59
- else:
60
- op_parsed_list.append({'full_op_name': op_name, 'full_info': None})
117
+
118
+ stack_info = stack_json_data.get(op_name)
119
+ if stack_info is not None:
120
+ check_stack_json_str(stack_info, op_name)
121
+ op_parsed_list.append({
122
+ 'full_op_name': op_name,
123
+ 'full_info': stack_info
124
+ })
61
125
 
62
126
  merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare)
63
127
  return merge_list
@@ -67,7 +131,7 @@ class Comparator:
67
131
  b_op_name = bench_dict["op_name"]
68
132
  graph_mode = check_graph_mode(a_op_name[0], b_op_name[0])
69
133
 
70
- frame_name = getattr(self,"frame_name")
134
+ frame_name = getattr(self, "frame_name")
71
135
  if frame_name == "PTComparator":
72
136
  from msprobe.pytorch.compare.match import graph_mapping
73
137
  if graph_mode:
@@ -94,11 +158,11 @@ class Comparator:
94
158
  return n_index, len(bench_queue) - 1
95
159
  return -1, -1
96
160
 
97
- def compare_process(self, file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
98
- npu_json_handle, bench_json_handle, stack_json_handle = file_handles
99
- npu_json_data = json.load(npu_json_handle)
100
- bench_json_data = json.load(bench_json_handle)
101
- stack_json_data = json.load(stack_json_handle)
161
+ def compare_process(self, file_lists, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False):
162
+ npu_json_path, bench_json_path, stack_json_path = file_lists
163
+ npu_json_data = load_json(npu_json_path)
164
+ bench_json_data = load_json(bench_json_path)
165
+ stack_json_data = load_json(stack_json_path)
102
166
 
103
167
  if fuzzy_match:
104
168
  logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.")
@@ -114,14 +178,19 @@ class Comparator:
114
178
  last_npu_ops_len = 0
115
179
  last_bench_ops_len = 0
116
180
 
181
+ npu_api_nums = len(npu_json_data['data'])
182
+ progress_bar = tqdm(total=npu_api_nums, desc="API/Module Read Progress", unit="item", ncols=100)
183
+
117
184
  while True:
118
185
  if not read_err_npu and not read_err_bench:
119
186
  break
120
187
  try:
121
188
  last_npu_ops_len = len(npu_ops_queue)
122
189
  op_name_npu = next(ops_npu_iter)
190
+ check_op_str_pattern_valid(op_name_npu)
123
191
  read_err_npu = True
124
- npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,summary_compare,md5_compare)
192
+ npu_merge_list = self.gen_merge_list(npu_json_data, op_name_npu, stack_json_data,
193
+ summary_compare, md5_compare)
125
194
  if npu_merge_list:
126
195
  npu_ops_queue.append(npu_merge_list)
127
196
  except StopIteration:
@@ -129,12 +198,16 @@ class Comparator:
129
198
  try:
130
199
  last_bench_ops_len = len(bench_ops_queue)
131
200
  op_name_bench = next(ops_bench_iter)
132
- bench_merge_list = self.gen_merge_list(bench_json_data,op_name_bench,stack_json_data,summary_compare,md5_compare)
201
+ check_op_str_pattern_valid(op_name_bench)
202
+ bench_merge_list = self.gen_merge_list(bench_json_data, op_name_bench, stack_json_data,
203
+ summary_compare, md5_compare)
133
204
  if bench_merge_list:
134
205
  bench_ops_queue.append(bench_merge_list)
135
206
  except StopIteration:
136
207
  read_err_bench = False
137
208
 
209
+ progress_bar.update(1)
210
+
138
211
  # merge all boolean expressions
139
212
  both_empty = not npu_ops_queue and not bench_ops_queue
140
213
  no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
@@ -163,7 +236,91 @@ class Comparator:
163
236
 
164
237
  result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
165
238
  return result_df
166
-
239
+
240
+ def merge_data(self, json_data, stack_json_data, summary_compare, md5_compare):
241
+ ops_all = {}
242
+ for op_name in json_data.get('data', {}):
243
+ merge_list = self.gen_merge_list(json_data, op_name, stack_json_data, summary_compare,
244
+ md5_compare)
245
+ if merge_list:
246
+ input_index, output_index = 0, 0
247
+ for index, input_or_output in enumerate(merge_list['op_name']):
248
+ input_or_output_list = input_or_output.split(Const.SEP)
249
+ data_name = merge_list.get('data_name')
250
+ data_name = data_name[index] if data_name else None
251
+ if Const.INPUT in input_or_output_list or Const.KWARGS in input_or_output_list:
252
+ ops_all[input_or_output] = {'struct': merge_list.get('input_struct')[input_index],
253
+ 'summary': merge_list.get('summary')[index],
254
+ 'data_name': data_name,
255
+ 'stack_info': merge_list.get('stack_info')}
256
+ input_index += 1
257
+
258
+ elif Const.OUTPUT in input_or_output_list:
259
+ ops_all[input_or_output] = {'struct': merge_list.get('output_struct')[output_index],
260
+ 'summary': merge_list.get('summary')[index],
261
+ 'data_name': data_name,
262
+ 'stack_info': merge_list.get('stack_info')}
263
+ output_index += 1
264
+ return ops_all
265
+
266
+ def get_accuracy(self, npu_ops_all, bench_ops_all, summary_compare, md5_compare):
267
+ result = []
268
+ for ms_op_name, bench_op_name in self.data_mapping_dict.items():
269
+ if ms_op_name in npu_ops_all and bench_op_name in bench_ops_all:
270
+ npu_stack_info = npu_ops_all.get(ms_op_name).get("stack_info", None)
271
+ bench_stack_info = bench_ops_all.get(bench_op_name).get("stack_info", None)
272
+ has_stack = npu_stack_info and bench_stack_info
273
+ if md5_compare:
274
+ result.append(self.get_result_md5_compare(ms_op_name, bench_op_name, npu_ops_all,
275
+ bench_ops_all, has_stack, npu_stack_info))
276
+ continue
277
+ if summary_compare:
278
+ result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
279
+ bench_ops_all.get(bench_op_name).get('struct')[0],
280
+ npu_ops_all.get(ms_op_name).get('struct')[1],
281
+ bench_ops_all.get(bench_op_name).get('struct')[1],
282
+ " ", " ", " ", " ", " ", " ", " ", " "]
283
+ else:
284
+ result_item = [ms_op_name, bench_op_name, npu_ops_all.get(ms_op_name).get('struct')[0],
285
+ bench_ops_all.get(bench_op_name).get('struct')[0],
286
+ npu_ops_all.get(ms_op_name).get('struct')[1],
287
+ bench_ops_all.get(bench_op_name).get('struct')[1],
288
+ " ", " ", " ", " ", " "]
289
+ npu_summary_data = npu_ops_all.get(ms_op_name).get("summary")
290
+ result_item.extend(npu_summary_data)
291
+ bench_summary_data = bench_ops_all.get(bench_op_name).get("summary")
292
+ result_item.extend(bench_summary_data)
293
+ if summary_compare:
294
+ self.calculate_summary_data(npu_summary_data, bench_summary_data, result_item)
295
+ else:
296
+ result_item.append(CompareConst.ACCURACY_CHECK_YES)
297
+ result_item.append("")
298
+ if has_stack:
299
+ result_item.extend(npu_stack_info)
300
+ else:
301
+ result_item.append(CompareConst.NONE)
302
+ if not (summary_compare or md5_compare):
303
+ result_item.append(npu_ops_all.get(ms_op_name).get("data_name", None))
304
+ result.append(result_item)
305
+ elif ms_op_name not in npu_ops_all:
306
+ logger.warning(f'Can not find npu op name : `{ms_op_name}` in npu dump json file.')
307
+ elif bench_op_name not in npu_ops_all:
308
+ logger.warning(f'Can not find bench op name : `{bench_op_name}` in bench dump json file.')
309
+ return result
310
+
311
+ def compare_process_custom(self, file_lists, stack_mode, summary_compare=False, md5_compare=False):
312
+ npu_json_path, bench_json_path, stack_json_path = file_lists
313
+ npu_json_data = load_json(npu_json_path)
314
+ bench_json_data = load_json(bench_json_path)
315
+ stack_json_data = load_json(stack_json_path)
316
+
317
+ npu_ops_all = self.merge_data(npu_json_data, stack_json_data, summary_compare, md5_compare)
318
+ bench_ops_all = self.merge_data(bench_json_data, stack_json_data, summary_compare, md5_compare)
319
+
320
+ result = self.get_accuracy(npu_ops_all, bench_ops_all, summary_compare, md5_compare)
321
+ result_df = self.make_result_table(result, md5_compare, summary_compare, stack_mode)
322
+ return result_df
323
+
167
324
  def compare_by_op(self, npu_op_name, bench_op_name, op_name_mapping_dict, input_param):
168
325
  npu_bench_name_list = op_name_mapping_dict[npu_op_name]
169
326
  data_name = npu_bench_name_list[1]
@@ -178,9 +335,11 @@ class Comparator:
178
335
  if frame_name == "MSComparator":
179
336
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.NUMPY_SUFFIX)
180
337
  if self.cross_frame:
181
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
338
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
339
+ bench_op_name + Const.PT_SUFFIX, load_pt_file=True)
182
340
  else:
183
- b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.NUMPY_SUFFIX)
341
+ b_value = read_npy_data(input_param.get("bench_dump_data_dir"),
342
+ bench_op_name + Const.NUMPY_SUFFIX)
184
343
  else:
185
344
  n_value = read_npy_data(input_param.get("npu_dump_data_dir"), npu_op_name + Const.PT_SUFFIX)
186
345
  b_value = read_npy_data(input_param.get("bench_dump_data_dir"), bench_op_name + Const.PT_SUFFIX)
@@ -237,19 +396,31 @@ class Comparator:
237
396
  file_path = os.path.join(os.path.realpath(output_path), file_name)
238
397
  remove_path(file_path)
239
398
  highlight_dict = {'red_rows': [], 'yellow_rows': []}
240
-
241
- with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \
242
- FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \
243
- FileOpen(input_parma.get("stack_json_path"), "r") as stack_json:
399
+
400
+ npu_json = input_parma.get("npu_json_path")
401
+ bench_json = input_parma.get("bench_json_path")
402
+ stack_json = input_parma.get("stack_json_path")
403
+ if self.data_mapping:
404
+ result_df = self.compare_process_custom([npu_json, bench_json, stack_json], stack_mode,
405
+ summary_compare, md5_compare)
406
+ else:
244
407
  result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match,
245
- summary_compare, md5_compare)
408
+ summary_compare, md5_compare)
409
+
410
+ if not result_df.values.tolist():
411
+ logger.warning("Can`t match any op.")
412
+ return
246
413
 
247
414
  if not md5_compare and not summary_compare:
248
415
  result_df = self._do_multi_process(input_parma, result_df)
416
+
417
+ logger.info("Highlight suspicious API/Module start.")
249
418
  find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare)
250
419
  highlight_rows_xlsx(result_df, highlight_dict, file_path)
420
+ logger.info("Highlight suspicious API/Module finish.")
421
+
251
422
  if auto_analyze:
252
- advisor = Advisor(result_df, output_path)
423
+ advisor = Advisor(result_df, output_path, suffix)
253
424
  advisor.analysis()
254
425
 
255
426
  def compare_ops(self, idx, dump_path_dict, result_df, lock, input_param):
@@ -265,13 +436,14 @@ class Comparator:
265
436
  bench_op_name = result_df.iloc[i, 1]
266
437
  if is_print_compare_log:
267
438
  logger.info("start compare: {}".format(npu_op_name))
268
- cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = self.compare_by_op(
269
- npu_op_name, bench_op_name, dump_path_dict, input_param)
439
+ cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = \
440
+ self.compare_by_op(npu_op_name, bench_op_name, dump_path_dict, input_param)
270
441
  if is_print_compare_log:
271
442
  logger.info(
272
- "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, "
273
- "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err, err_msg,
274
- one_thousand_err_ratio, five_thousand_err_ratio))
443
+ "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, \
444
+ one_thousand_err_ratio {}, "
445
+ "five_thousand_err_ratio {}".format(npu_op_name, cos_sim, max_abs_err, max_relative_err,
446
+ err_msg, one_thousand_err_ratio, five_thousand_err_ratio))
275
447
  cos_result.append(cos_sim)
276
448
  max_err_result.append(max_abs_err)
277
449
  max_relative_err_result.append(max_relative_err)
@@ -290,9 +462,10 @@ class Comparator:
290
462
 
291
463
  return _save_cmp_result(idx, cr, result_df, lock)
292
464
 
293
- def _do_multi_process(self,input_parma, result_df):
465
+ def _do_multi_process(self, input_parma, result_df):
294
466
  try:
295
- result_df = _handle_multi_process(self.compare_ops, input_parma, result_df, multiprocessing.Manager().RLock())
467
+ result_df = _handle_multi_process(self.compare_ops, input_parma, result_df,
468
+ multiprocessing.Manager().RLock())
296
469
  return result_df
297
470
  except ValueError as e:
298
471
  logger.error('result dataframe is not found.')
@@ -1,5 +1,22 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from msprobe.core.common.log import logger
2
- from msprobe.core.compare.utils import rename_api
17
+ from msprobe.core.compare.utils import rename_api
18
+ from msprobe.core.common.utils import check_op_str_pattern_valid, CompareException
19
+ from msprobe.core.common.const import Const
3
20
 
4
21
 
5
22
  dtype_mapping = {
@@ -34,8 +51,15 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
34
51
  if not is_match:
35
52
  if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in):
36
53
  return False
37
- struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
38
- struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
54
+ try:
55
+ struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in)
56
+ struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out)
57
+ except CompareException as error:
58
+ err_msg = f'index out of bounds error occurs in npu or bench api, please check!\n' \
59
+ f'npu_dict: {npu_dict}' \
60
+ f'bench_dict: {bench_dict}'
61
+ logger.error(err_msg)
62
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
39
63
  is_match = struct_in_is_match and struct_out_is_match
40
64
  return is_match
41
65
 
@@ -43,17 +67,27 @@ def check_struct_match(npu_dict, bench_dict, cross_frame=False):
43
67
  def check_type_shape_match(npu_struct, bench_struct):
44
68
  shape_type_match = False
45
69
  for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct):
46
- npu_type = npu_type_shape[0]
47
- npu_shape = npu_type_shape[1]
48
- bench_type = bench_type_shape[0]
49
- bench_shape = bench_type_shape[1]
70
+ try:
71
+ npu_type = npu_type_shape[0]
72
+ npu_shape = npu_type_shape[1]
73
+ bench_type = bench_type_shape[0]
74
+ bench_shape = bench_type_shape[1]
75
+ except IndexError as error:
76
+ logger.error(f'length of npu_type_shape: {npu_type_shape} and bench_type_shape: {bench_type_shape} '
77
+ f'should both be 2, please check!')
78
+ raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from error
50
79
  shape_match = npu_shape == bench_shape
51
80
  type_match = npu_type == bench_type
52
81
  if not type_match:
53
- ms_type=[["Float16", "Float32"], ["Float32", "Float16"],["Float16", "BFloat16"],["BFloat16", "Float16"]]
54
- torch_type=[["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"],
55
- ["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]]
56
- if ([npu_type, bench_type] in ms_type)or ([npu_type, bench_type] in torch_type):
82
+ ms_type = [
83
+ [Const.FLOAT16, Const.FLOAT32], [Const.FLOAT32, Const.FLOAT16],
84
+ [Const.FLOAT16, Const.BFLOAT16], [Const.BFLOAT16, Const.FLOAT16]
85
+ ]
86
+ torch_type = [
87
+ [Const.TORCH_FLOAT16, Const.TORCH_FLOAT32], [Const.TORCH_FLOAT32, Const.TORCH_FLOAT16],
88
+ [Const.TORCH_FLOAT16, Const.TORCH_BFLOAT16], [Const.TORCH_BFLOAT16, Const.TORCH_FLOAT16]
89
+ ]
90
+ if ([npu_type, bench_type] in ms_type) or ([npu_type, bench_type] in torch_type):
57
91
  type_match = True
58
92
  else:
59
93
  type_match = False
@@ -64,9 +98,9 @@ def check_type_shape_match(npu_struct, bench_struct):
64
98
 
65
99
 
66
100
def check_graph_mode(a_op_name, b_op_name):
    """Report whether exactly one of the two op names contains ``Const.ATEN``.

    Args:
        a_op_name: first operator name to inspect.
        b_op_name: second operator name to inspect.

    Returns:
        True when the ``Const.ATEN`` marker occurs in one name but not in the
        other; False when it occurs in both names or in neither.
    """
    a_has_aten = Const.ATEN in a_op_name
    b_has_aten = Const.ATEN in b_op_name
    # The two membership results differ only in the "exactly one" case.
    return a_has_aten != b_has_aten
72
106
 
@@ -83,13 +117,64 @@ def fuzzy_check_op(npu_name_list, bench_name_list):
83
117
 
84
118
 
85
119
def fuzzy_check_name(npu_name, bench_name):
    """Compare an NPU op name with a bench op name using direction-aware matching.

    Args:
        npu_name: op name taken from the NPU side.
        bench_name: op name taken from the bench side.

    Returns:
        If both names contain ``Const.FORWARD`` (checked first) or both contain
        ``Const.BACKWARD``, the result of comparing ``rename_api`` applied to
        each name with that direction marker. Otherwise the result of a plain
        equality comparison of the two names.
    """
    for direction in (Const.FORWARD, Const.BACKWARD):
        # Only a direction present in *both* names selects the renamed comparison.
        if direction in npu_name and direction in bench_name:
            return rename_api(npu_name, direction) == rename_api(bench_name, direction)
    return npu_name == bench_name
93
127
 
94
128
 
129
def check_dump_json_str(op_data, op_name):
    """Run ``check_json_key_value`` over the inputs and outputs recorded for one op.

    Three sections are read from ``op_data``: the positional inputs
    (``Const.INPUT_ARGS``, falling back to ``Const.INPUT`` when the former is
    missing or empty), the keyword inputs (``Const.INPUT_KWARGS``) and the
    outputs (``Const.OUTPUT``). Missing or empty sections are skipped.

    Args:
        op_data: mapping holding the recorded data of the op.
        op_name: name of the op, forwarded to the checks for error reporting.
    """
    positional_inputs = op_data.get(Const.INPUT_ARGS) or op_data.get(Const.INPUT)
    keyword_inputs = op_data.get(Const.INPUT_KWARGS)
    outputs = op_data.get(Const.OUTPUT)

    for section in (positional_inputs, keyword_inputs, outputs):
        if not section:
            continue
        if isinstance(section, dict):
            # A dict section is checked as a whole.
            check_json_key_value(section, op_name)
            continue
        # Any other section is iterated; empty elements are skipped.
        for element in section:
            if element:
                check_json_key_value(element, op_name)
+
147
+
148
def check_json_key_value(input_output, op_name, depth=0):
    """Recursively walk a nested list/dict structure and validate its leaf entries.

    Lists are traversed item by item. For dicts, nested dict values are
    traversed further, while every other value is handed to
    ``valid_key_value`` together with its key. Objects that are neither a list
    nor a dict are ignored.

    Args:
        input_output: the list, dict or other object to inspect.
        op_name: name of the op, used in log messages and forwarded to the checks.
        depth: current recursion depth; once it exceeds ``Const.MAX_DEPTH`` an
            error is logged and the traversal of this branch stops.
    """
    if depth > Const.MAX_DEPTH:
        logger.error(f"string check of data info of {op_name} exceeds the recursion limit.")
        return

    next_depth = depth + 1

    if isinstance(input_output, list):
        for entry in input_output:
            check_json_key_value(entry, op_name, next_depth)
        return

    if not isinstance(input_output, dict):
        return

    for field, field_value in input_output.items():
        if isinstance(field_value, dict):
            check_json_key_value(field_value, op_name, next_depth)
        else:
            valid_key_value(field, field_value, op_name)
161
+
95
162
 
163
def valid_key_value(key, value, op_name):
    """Validate a single key/value pair taken from an op's input or output data.

    A ``"shape"`` value must be a list or tuple and a ``"requires_grad"`` value
    must be a bool. Every pair that does not violate these type rules
    (including correctly typed ``"shape"`` / ``"requires_grad"`` values) is
    passed to ``check_op_str_pattern_valid``.

    Args:
        key: the key of the pair.
        value: the value stored under ``key``.
        op_name: name of the op, used in log messages and forwarded to the check.

    Raises:
        CompareException: with ``INVALID_OBJECT_TYPE_ERROR`` when a ``"shape"``
            or ``"requires_grad"`` value has the wrong type.
    """
    if key == "shape" and not isinstance(value, (list, tuple)):
        type_error_msg = f"shape of input or output of {op_name} is not list or tuple, please check!"
    elif key == "requires_grad" and not isinstance(value, bool):
        type_error_msg = f"requires_grad of input or output of {op_name} is not bool, please check!"
    else:
        check_op_str_pattern_valid(value, op_name)
        return

    # Both type violations are reported the same way: log, then raise.
    logger.error(type_error_msg)
    raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
172
+
173
+
174
def check_stack_json_str(stack_info, op_name):
    """Validate every entry of an op's stack information.

    Args:
        stack_info: expected to be a list; each item is passed to
            ``check_op_str_pattern_valid`` with ``stack=True``.
        op_name: name of the op, used in the log message and forwarded to the check.

    Raises:
        CompareException: with ``INVALID_OBJECT_TYPE_ERROR`` when ``stack_info``
            is not a list.
    """
    if not isinstance(stack_info, list):
        logger.error(f"Expected stack_info to be a list, but got {type(stack_info).__name__} for '{op_name}'")
        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)

    for entry in stack_info:
        check_op_str_pattern_valid(entry, op_name, stack=True)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import json
2
17
  from msprobe.core.common.file_utils import FileOpen, check_file_type
3
18
  from msprobe.core.common.const import FileCheckConst, Const
@@ -23,8 +38,11 @@ def compare_cli(args):
23
38
  input_param["bench_json_path"] = input_param.pop("bench_path")
24
39
  input_param["stack_json_path"] = input_param.pop("stack_path")
25
40
  if frame_name == Const.PT_FRAMEWORK:
41
+ kwargs = {
42
+ "data_mapping": args.data_mapping
43
+ }
26
44
  compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=auto_analyze,
27
- fuzzy_match=args.fuzzy_match)
45
+ fuzzy_match=args.fuzzy_match, **kwargs)
28
46
  else:
29
47
  kwargs = {
30
48
  "stack_mode": args.stack_mode,
@@ -32,6 +50,8 @@ def compare_cli(args):
32
50
  "fuzzy_match": args.fuzzy_match,
33
51
  "cell_mapping": args.cell_mapping,
34
52
  "api_mapping": args.api_mapping,
53
+ "data_mapping": args.data_mapping,
54
+ "layer_mapping": args.layer_mapping
35
55
  }
36
56
 
37
57
  ms_compare(input_param, args.output_path, **kwargs)
@@ -1,5 +1,21 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
  import abc
18
+ import re
3
19
  from collections import namedtuple
4
20
  import numpy as np
5
21
  import openpyxl
@@ -7,7 +23,7 @@ from openpyxl.styles import PatternFill
7
23
  from msprobe.core.common.utils import get_header_index
8
24
  from msprobe.core.common.file_utils import save_workbook
9
25
  from msprobe.core.common.log import logger
10
- from msprobe.core.common.const import CompareConst
26
+ from msprobe.core.common.const import CompareConst, FileCheckConst
11
27
 
12
28
 
13
29
  class HighlightCheck(abc.ABC):
@@ -34,9 +50,11 @@ class CheckOneThousandErrorRatio(HighlightCheck):
34
50
  def apply(self, info, color_columns, summary_compare=True):
35
51
  api_in, api_out, num = info
36
52
  one_thousand_index = get_header_index('One Thousandth Err Ratio', summary_compare)
37
- if not isinstance(api_in[one_thousand_index], (float, int)) or not isinstance(api_out[one_thousand_index], (float, int)):
53
+ if (not isinstance(api_in[one_thousand_index], (float, int)) or
54
+ not isinstance(api_out[one_thousand_index], (float, int))):
38
55
  return
39
- if api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED:
56
+ if (api_in[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_IN_RED and
57
+ api_out[one_thousand_index] < CompareConst.ONE_THOUSAND_ERROR_OUT_RED):
40
58
  color_columns.red.append(num)
41
59
  elif api_in[one_thousand_index] - api_out[one_thousand_index] > CompareConst.ONE_THOUSAND_ERROR_DIFF_YELLOW:
42
60
  color_columns.yellow.append(num)
@@ -66,7 +84,8 @@ class CheckMaxRelativeDiff(HighlightCheck):
66
84
  return
67
85
  if output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_RED:
68
86
  color_columns.red.append(num)
69
- elif output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW:
87
+ elif (output_max_relative_diff > CompareConst.MAX_RELATIVE_OUT_YELLOW and
88
+ input_max_relative_diff < CompareConst.MAX_RELATIVE_IN_YELLOW):
70
89
  color_columns.yellow.append(num)
71
90
 
72
91
 
@@ -193,7 +212,8 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, m
193
212
  input_num = num
194
213
  else:
195
214
  output_num = num
196
- find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare)
215
+ find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict,
216
+ summary_compare, md5_compare)
197
217
 
198
218
 
199
219
  def highlight_rows_xlsx(result_df, highlight_dict, file_path):
@@ -205,12 +225,16 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
205
225
 
206
226
  # write header
207
227
  for j, col_name in enumerate(result_df.columns, start=1):
228
+ if not csv_value_is_valid(col_name):
229
+ raise RuntimeError(f"Malicious value [{col_name}] is not allowed to be written into the xlsx: {file_path}.")
208
230
  ws.cell(row=1, column=j, value=col_name)
209
231
 
210
232
  for i, row in enumerate(result_df.iterrows(), start=2):
211
233
  for j, value in enumerate(row[1], start=1):
212
234
  if not isinstance(value, (float, int)):
213
235
  value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value)
236
+ if not csv_value_is_valid(value):
237
+ raise RuntimeError(f"Malicious value [{value}] is not allowed to be written into the xlsx: {file_path}.")
214
238
  ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value)
215
239
 
216
240
  if (i - 2) in highlight_dict['red_rows']:
@@ -221,3 +245,15 @@ def highlight_rows_xlsx(result_df, highlight_dict, file_path):
221
245
  end_color=CompareConst.YELLOW, fill_type="solid")
222
246
 
223
247
  save_workbook(wb, file_path)
248
+
249
+
250
def csv_value_is_valid(value: str) -> bool:
    """Decide whether a cell value is acceptable to write (formula-injection guard).

    Non-string values are always accepted, and so are strings that ``float()``
    can parse (for example ``"-1.00"`` or ``"+1.00"``), since they are plain
    numbers. Any other string is accepted only when the pattern
    ``FileCheckConst.CSV_BLACK_LIST`` matches nowhere in it.

    Args:
        value: the candidate cell value.

    Returns:
        True if the value may be written, False if it matches the blacklist.
    """
    if not isinstance(value, str):
        return True

    try:
        # Signed numbers such as -1.00 or +1.00 are ordinary numeric strings.
        float(value)
    except ValueError:
        # Not a number: treat a blacklist match as a possible formula injection.
        return re.search(FileCheckConst.CSV_BLACK_LIST, value) is None
    return True