mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -14,11 +14,13 @@ class Const:
14
14
  REGEX_PREFIX_MAX_LENGTH = 20
15
15
  REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
16
16
  FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
17
+ STRING_BLACKLIST = r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]"
17
18
  COMMA = ","
18
19
  FLOAT_EPSILON = np.finfo(float).eps
19
20
  OFF = 'OFF'
20
21
  BACKWARD = 'backward'
21
22
  FORWARD = 'forward'
23
+ JIT = 'Jit'
22
24
  PRIMITIVE_PREFIX = 'Primitive'
23
25
  DEFAULT_LIST = []
24
26
  DEFAULT_PATH = './'
@@ -30,6 +32,7 @@ class Const:
30
32
  FOUR_SEGMENT = 4
31
33
  SIX_SEGMENT = 6
32
34
  SEVEN_SEGMENT = 7
35
+ MAX_DEPTH = 10
33
36
 
34
37
  # dump mode
35
38
  ALL = "all"
@@ -78,6 +81,7 @@ class Const:
78
81
  RUN_UT = "run_ut"
79
82
  GRAD_PROBE = "grad_probe"
80
83
  TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE]
84
+ DUMP_DATA_COLLECTION_LIST = [STATISTICS, TENSOR]
81
85
  LEVEL_L0 = "L0"
82
86
  LEVEL_L1 = "L1"
83
87
  LEVEL_L2 = "L2"
@@ -100,6 +104,30 @@ class Const:
100
104
  CUDA_LOWERCASE = 'cuda'
101
105
  DISTRIBUTED = 'Distributed'
102
106
 
107
+ # struct json param
108
+ ORIGIN_DATA = "origin_data"
109
+ SCOPE = "scope"
110
+ STACK = "stack"
111
+
112
+ ATEN = "Aten"
113
+ MODULE_WHITE_LIST = ["torch", "numpy"]
114
+
115
+ FUNC_SKIP_LIST = ["construct", "__call__"]
116
+
117
+ FILE_SKIP_LIST = ["site-packages/mindspore", "package/mindspore", "msprobe", "site-packages/torch", "package/torch"]
118
+
119
+ STACK_FILE_INDEX = 0
120
+
121
+ STACK_FUNC_INDEX = 2
122
+
123
+ STACK_FUNC_ELE_INDEX = 1
124
+
125
+ CONSTRUCT_NAME_INDEX = -3
126
+
127
+ NAME_FIRST_POSSIBLE_INDEX = -4
128
+
129
+ NAME_SECOND_POSSIBLE_INDEX = -5
130
+
103
131
  INPLACE_LIST = [
104
132
  "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
105
133
  "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all",
@@ -114,6 +142,23 @@ class Const:
114
142
  "int32_to_int64": ["cross_entropy"]
115
143
  }
116
144
 
145
+ FILL_CHAR_NUMS = 50
146
+ TOOL_ENDS_SUCCESSFULLY = f"{TOOL_NAME} ends successfully."
147
+ WITHOUT_CALL_STACK = "The call stack retrieval failed."
148
+
149
+ STEP = "step"
150
+ RANK = "rank"
151
+ HYPHEN = "-"
152
+ STEP_RANK_MAXIMUM_RANGE = [int(0), int(1e6)]
153
+
154
+ # data type const
155
+ FLOAT16 = "Float16"
156
+ FLOAT32 = "Float32"
157
+ BFLOAT16 = "BFloat16"
158
+ TORCH_FLOAT16 = "torch.float16"
159
+ TORCH_FLOAT32 = "torch.float32"
160
+ TORCH_BFLOAT16 = "torch.bfloat16"
161
+
117
162
 
118
163
  class CompareConst:
119
164
  """
@@ -159,6 +204,7 @@ class CompareConst:
159
204
  INPUT_STRUCT = "input_struct"
160
205
  OUTPUT_STRUCT = "output_struct"
161
206
  SUMMARY = "summary"
207
+ MAX_EXCEL_LENGTH = 1048576
162
208
 
163
209
  COMPARE_RESULT_HEADER = [
164
210
  NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
@@ -197,6 +243,8 @@ class CompareConst:
197
243
  ERROR = 'error'
198
244
  SKIP = 'SKIP'
199
245
  N_A = 'N/A'
246
+ INF = 'inf'
247
+ NEG_INF = '-inf'
200
248
  BFLOAT16_MIN = -3.3895313892515355e+38
201
249
  BFLOAT16_MAX = 3.3895313892515355e+38
202
250
  BFLOAT16_EPS = 3.90625e-3 # 2 ** -8
@@ -274,7 +322,8 @@ class FileCheckConst:
274
322
  MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
275
323
  MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024
276
324
  MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024
277
- MAX_YAML_SIZE = 1048576 # 10 * 1024 * 1024
325
+ MAX_YAML_SIZE = 1048576 # 1 * 1024 * 1024
326
+ COMMOM_FILE_SIZE = 1048576 # 1 * 1024 * 1024
278
327
  DIR = "dir"
279
328
  FILE = "file"
280
329
  DATA_DIR_AUTHORITY = 0o750
@@ -287,6 +336,7 @@ class FileCheckConst:
287
336
  CSV_SUFFIX: MAX_CSV_SIZE,
288
337
  YAML_SUFFIX: MAX_YAML_SIZE
289
338
  }
339
+ CSV_BLACK_LIST = r'^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]'
290
340
 
291
341
 
292
342
  class OverflowConst:
@@ -329,11 +379,22 @@ class MsgConst:
329
379
  """
330
380
  Class for log messages const
331
381
  """
332
- CLEAR_SYMBOL = "\033[K"
333
382
  MSPROBE_LOG_LEVEL = "MSPROBE_LOG_LEVEL"
334
- LEVEL = ["INFO", "WARNING", "ERROR", "DEBUG"]
383
+ LOG_LEVEL_ENUM = ["0", "1", "2", "3", "4"]
384
+ LOG_LEVEL = ["DEBUG", "INFO", "WARNING", "ERROR"]
385
+ class LogLevel:
386
+ class DEBUG:
387
+ value = 0
388
+ class INFO:
389
+ value = 1
390
+ class WARNING:
391
+ value = 2
392
+ class ERROR:
393
+ value = 3
335
394
  SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]
336
395
 
396
+ NOT_CREATED_INSTANCE = "PrecisionDebugger instance is not created."
397
+
337
398
 
338
399
  class GraphMode:
339
400
  NPY_MODE = "NPY_MODE"
@@ -13,8 +13,8 @@ class MsprobeException(CodedException):
13
13
  OVERFLOW_NUMS_ERROR = 1
14
14
 
15
15
  err_strs = {
16
- INVALID_PARAM_ERROR: "[msprobe] 无效参数: ",
17
- OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
16
+ INVALID_PARAM_ERROR: "[msprobe] 无效参数:",
17
+ OVERFLOW_NUMS_ERROR: "[msprobe] 超过预设溢出次数 当前溢出次数:"
18
18
  }
19
19
 
20
20
 
@@ -22,6 +22,7 @@ import re
22
22
  import shutil
23
23
  import yaml
24
24
  import numpy as np
25
+ import pandas as pd
25
26
 
26
27
  from msprobe.core.common.log import logger
27
28
  from msprobe.core.common.exceptions import FileCheckException
@@ -187,7 +188,7 @@ def check_other_user_writable(path):
187
188
 
188
189
  def check_path_owner_consistent(path):
189
190
  file_owner = os.stat(path).st_uid
190
- if file_owner != os.getuid():
191
+ if file_owner != os.getuid() and os.getuid() != 0:
191
192
  logger.error('The file path %s may be insecure because is does not belong to you.' % path)
192
193
  raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
193
194
 
@@ -214,7 +215,9 @@ def check_common_file_size(file_path):
214
215
  for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
215
216
  if file_path.endswith(suffix):
216
217
  check_file_size(file_path, max_size)
217
- break
218
+ return
219
+ check_file_size(file_path, FileCheckConst.COMMOM_FILE_SIZE)
220
+
218
221
 
219
222
 
220
223
  def check_file_suffix(file_path, file_suffix):
@@ -322,7 +325,7 @@ def check_file_type(path):
322
325
  elif os.path.isfile(path):
323
326
  return FileCheckConst.FILE
324
327
  else:
325
- logger.error('Neither a file nor a directory.')
328
+ logger.error(f'{path} does not exist, please check!')
326
329
  raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
327
330
 
328
331
 
@@ -338,10 +341,10 @@ def load_yaml(yaml_path):
338
341
  return yaml_data
339
342
 
340
343
 
341
- def load_npy(filepath, enable_pickle=False):
344
+ def load_npy(filepath):
342
345
  check_file_or_directory_path(filepath)
343
346
  try:
344
- npy = np.load(filepath, allow_pickle=enable_pickle)
347
+ npy = np.load(filepath)
345
348
  except Exception as e:
346
349
  logger.error(f"The numpy file failed to load. Please check the path: {filepath}.")
347
350
  raise RuntimeError(f"Load numpy file {filepath} failed.") from e
@@ -374,6 +377,20 @@ def save_json(json_path, data, indent=None):
374
377
  change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY)
375
378
 
376
379
 
380
+ def save_yaml(yaml_path, data):
381
+ yaml_path = os.path.realpath(yaml_path)
382
+ check_path_before_create(yaml_path)
383
+ try:
384
+ with FileOpen(yaml_path, 'w') as f:
385
+ fcntl.flock(f, fcntl.LOCK_EX)
386
+ yaml.dump(data, f, sort_keys=False)
387
+ fcntl.flock(f, fcntl.LOCK_UN)
388
+ except Exception as e:
389
+ logger.error(f'Save yaml file "{os.path.basename(yaml_path)}" failed.')
390
+ raise RuntimeError(f"Save yaml file {yaml_path} failed.") from e
391
+ change_mode(yaml_path, FileCheckConst.DATA_FILE_AUTHORITY)
392
+
393
+
377
394
  def move_file(src_path, dst_path):
378
395
  check_file_or_directory_path(src_path)
379
396
  check_path_before_create(dst_path)
@@ -396,9 +413,9 @@ def save_npy(data, filepath):
396
413
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
397
414
 
398
415
 
399
- def save_npy_to_txt(self, data, dst_file='', align=0):
416
+ def save_npy_to_txt(data, dst_file='', align=0):
400
417
  if os.path.exists(dst_file):
401
- self.log.info("Dst file %s exists, will not save new one.", dst_file)
418
+ logger.info("Dst file %s exists, will not save new one." % dst_file)
402
419
  return
403
420
  shape = data.shape
404
421
  data = data.flatten()
@@ -411,7 +428,7 @@ def save_npy_to_txt(self, data, dst_file='', align=0):
411
428
  try:
412
429
  np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
413
430
  except Exception as e:
414
- self.log.error("An unexpected error occurred: %s when savetxt to %s" % (str(e)), dst_file)
431
+ logger.error("An unexpected error occurred: %s when savetxt to %s" % (str(e), dst_file))
415
432
  change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
416
433
 
417
434
 
@@ -431,7 +448,25 @@ def save_workbook(workbook, file_path):
431
448
  change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
432
449
 
433
450
 
434
- def write_csv(data, filepath, mode="a+"):
451
+ def write_csv(data, filepath, mode="a+", malicious_check=False):
452
+ def csv_value_is_valid(value: str) -> bool:
453
+ if not isinstance(value, str):
454
+ return True
455
+ try:
456
+ # -1.00 or +1.00 should be consdiered as digit numbers
457
+ float(value)
458
+ except ValueError:
459
+ # otherwise, they will be considered as formular injections
460
+ return not bool(re.compile(FileCheckConst.CSV_BLACK_LIST).search(value))
461
+ return True
462
+
463
+ if malicious_check:
464
+ for row in data:
465
+ for cell in row:
466
+ if not csv_value_is_valid(cell):
467
+ raise RuntimeError(f"Malicious value [{cell}] is not allowed " \
468
+ f"to be written into the csv: {filepath}.")
469
+
435
470
  file_path = os.path.realpath(filepath)
436
471
  check_path_before_create(filepath)
437
472
  try:
@@ -444,6 +479,16 @@ def write_csv(data, filepath, mode="a+"):
444
479
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
445
480
 
446
481
 
482
+ def read_csv(filepath):
483
+ check_file_or_directory_path(filepath)
484
+ try:
485
+ csv_data = pd.read_csv(filepath)
486
+ except Exception as e:
487
+ logger.error(f"The csv file failed to load. Please check the path: {filepath}.")
488
+ raise RuntimeError(f"Read csv file {filepath} failed.") from e
489
+ return csv_data
490
+
491
+
447
492
  def remove_path(path):
448
493
  if not os.path.exists(path):
449
494
  return
@@ -0,0 +1,38 @@
1
+ import os
2
+ from msprobe.core.common.file_utils import load_yaml
3
+
4
+
5
+ class InplaceOpChecker:
6
+ OP_FUNCTIONAL = 'functional'
7
+ OP_TENSOR = 'tensor'
8
+ OP_TORCH = 'torch'
9
+ OP_DISTRIBUTED = 'distributed'
10
+
11
+ INPLACE_OPS_DICT = None
12
+
13
+ @classmethod
14
+ def load_ops(cls):
15
+ if cls.INPLACE_OPS_DICT is None:
16
+ cls.INPLACE_OPS_DICT = dict()
17
+ cur_path = os.path.dirname(os.path.realpath(__file__))
18
+ yaml_path = os.path.join(cur_path, "inplace_ops.yaml")
19
+ all_ops = load_yaml(yaml_path)
20
+ cls.INPLACE_OPS_DICT[cls.OP_FUNCTIONAL] = all_ops.get('inplace_functional_op')
21
+ cls.INPLACE_OPS_DICT[cls.OP_TENSOR] = all_ops.get('inplace_tensor_op')
22
+ cls.INPLACE_OPS_DICT[cls.OP_TORCH] = all_ops.get('inplace_torch_op')
23
+ cls.INPLACE_OPS_DICT[cls.OP_DISTRIBUTED] = all_ops.get('inplace_distributed_op')
24
+
25
+ @classmethod
26
+ def check(cls, api, category='distributed'):
27
+ """
28
+ 给定api和分类,检查其是否为inplace操作
29
+ """
30
+ if not cls.INPLACE_OPS_DICT:
31
+ cls.load_ops()
32
+
33
+ if category not in cls.INPLACE_OPS_DICT.keys():
34
+ return False
35
+ return api in cls.INPLACE_OPS_DICT[category]
36
+
37
+
38
+ InplaceOpChecker.load_ops()
@@ -0,0 +1,251 @@
1
+ inplace_functional_op:
2
+ - threshold_
3
+ - relu_
4
+ - hardtanh_
5
+ - elu_
6
+ - selu_
7
+ - celu_
8
+ - leaky_relu_
9
+ - rrelu_
10
+
11
+ inplace_tensor_op:
12
+ - __iadd__
13
+ - __iand__
14
+ - __idiv__
15
+ - __ifloordiv__
16
+ - __ilshift__
17
+ - __imod__
18
+ - __imul__
19
+ - __ior__
20
+ - __irshift__
21
+ - __isub__
22
+ - __ixor__
23
+ - abs_
24
+ - absolute_
25
+ - acos_
26
+ - acosh_
27
+ - add_
28
+ - addbmm_
29
+ - addcdiv_
30
+ - addcmul_
31
+ - addmm_
32
+ - addmv_
33
+ - addr_
34
+ - arccos_
35
+ - arccosh_
36
+ - arcsin_
37
+ - arcsinh_
38
+ - arctan_
39
+ - arctanh_
40
+ - asin_
41
+ - asinh_
42
+ - atan2_
43
+ - atan_
44
+ - atanh_
45
+ - baddbmm_
46
+ - bernoulli_
47
+ - bitwise_and_
48
+ - bitwise_not_
49
+ - bitwise_or_
50
+ - bitwise_xor_
51
+ - cauchy_
52
+ - ceil_
53
+ - clamp_
54
+ - clamp_max_
55
+ - clamp_min_
56
+ - clip_
57
+ - copysign_
58
+ - cos_
59
+ - cosh_
60
+ - cumprod_
61
+ - cumsum_
62
+ - deg2rad_
63
+ - digamma_
64
+ - div_
65
+ - divide_
66
+ - eq_
67
+ - erf_
68
+ - erfc_
69
+ - erfinv_
70
+ - exp2_
71
+ - exp_
72
+ - expm1_
73
+ - exponential_
74
+ - fill_
75
+ - fill_diagonal_
76
+ - fix_
77
+ - float_power_
78
+ - floor_
79
+ - floor_divide_
80
+ - fmod_
81
+ - frac_
82
+ - gcd_
83
+ - ge_
84
+ - geometric_
85
+ - greater_
86
+ - gt_
87
+ - greater_equal_
88
+ - heaviside_
89
+ - hypot_
90
+ - igamma_
91
+ - igammac_
92
+ - index_add_
93
+ - index_copy_
94
+ - index_fill_
95
+ - index_put_
96
+ - lcm_
97
+ - ldexp_
98
+ - le_
99
+ - lerp_
100
+ - less_
101
+ - less_equal_
102
+ - lgamma_
103
+ - log10_
104
+ - log1p_
105
+ - log2_
106
+ - log_
107
+ - log_normal_
108
+ - logical_and_
109
+ - logical_not_
110
+ - logical_or_
111
+ - logical_xor_
112
+ - logit_
113
+ - lt_
114
+ - map2_
115
+ - map_
116
+ - masked_fill_
117
+ - masked_scatter_
118
+ - mul_
119
+ - multiply_
120
+ - mvlgamma_
121
+ - ne_
122
+ - neg_
123
+ - negative_
124
+ - normal_
125
+ - not_equal_
126
+ - pow_
127
+ - polygamma_
128
+ - put_
129
+ - rad2deg_
130
+ - reciprocal_
131
+ - relu_
132
+ - remainder_
133
+ - renorm_
134
+ - resize_
135
+ - resize_as_
136
+ - round_
137
+ - rsqrt_
138
+ - scatter_
139
+ - scatter_add_
140
+ - sgn_
141
+ - sigmoid_
142
+ - sign_
143
+ - sin_
144
+ - sinc_
145
+ - sinh_
146
+ - sqrt_
147
+ - square_
148
+ - squeeze_
149
+ - sub_
150
+ - t_
151
+ - tan_
152
+ - tanh_
153
+ - transpose_
154
+ - tril_
155
+ - triu_
156
+ - true_divide_
157
+ - trunc_
158
+ - unsqueeze_
159
+ - xlogy_
160
+
161
+ inplace_torch_op:
162
+ - _add_relu_
163
+ - abs_
164
+ - acos_
165
+ - acosh_
166
+ - addmv_
167
+ - alpha_dropout_
168
+ - arccos_
169
+ - arccosh_
170
+ - arcsin_
171
+ - arcsinh_
172
+ - arctan_
173
+ - arctanh_
174
+ - asin_
175
+ - asinh_
176
+ - atan_
177
+ - atanh_
178
+ - ceil_
179
+ - celu_
180
+ - clamp_
181
+ - clamp_max_
182
+ - clamp_min_
183
+ - clip_
184
+ - cos_
185
+ - cosh_
186
+ - deg2rad_
187
+ - dropout_
188
+ - embedding_renorm_
189
+ - erf_
190
+ - erfc_
191
+ - exp2_
192
+ - exp_
193
+ - expm1_
194
+ - feature_alpha_dropout_
195
+ - feature_dropout_
196
+ - fill_
197
+ - fix_
198
+ - floor_
199
+ - frac_
200
+ - gcd_
201
+ - index_put_
202
+ - lcm_
203
+ - ldexp_
204
+ - log10_
205
+ - log1p_
206
+ - log2_
207
+ - log_
208
+ - logit_
209
+ - nan_to_num_
210
+ - neg_
211
+ - negative_
212
+ - rad2deg_
213
+ - reciprocal_
214
+ - relu_
215
+ - resize_as_
216
+ - round_
217
+ - rrelu_
218
+ - rsqrt_
219
+ - selu_
220
+ - sigmoid_
221
+ - sin_
222
+ - sinc_
223
+ - sinh_
224
+ - sqrt_
225
+ - square_
226
+ - tan_
227
+ - tanh_
228
+ - threshold_
229
+ - trunc_
230
+ - xlogy_
231
+
232
+ inplace_distributed_op:
233
+ - broadcast
234
+ - all_reduce
235
+ - reduce
236
+ - all_gather
237
+ - gather
238
+ - scatter
239
+ - reduce_scatter
240
+ - _reduce_scatter_base
241
+ - _all_gather_base
242
+ - send
243
+ - recv
244
+ - irecv
245
+ - isend
246
+ - all_to_all_single
247
+ - all_to_all
248
+ - all_gather_into_tensor
249
+ - reduce_scatter_tensor
250
+
251
+
@@ -4,13 +4,19 @@ import sys
4
4
  from functools import wraps
5
5
  from msprobe.core.common.const import MsgConst
6
6
 
7
- MSPROBE_LOG_LEVEL = os.environ.get(MsgConst.MSPROBE_LOG_LEVEL, "")
8
-
9
7
 
10
8
  class BaseLogger:
11
9
  def __init__(self):
12
10
  self.rank = None
11
+ self.level = self.get_level()
13
12
 
13
+ @staticmethod
14
+ def get_level():
15
+ input_level = os.environ.get(MsgConst.MSPROBE_LOG_LEVEL)
16
+ if input_level not in MsgConst.LOG_LEVEL_ENUM:
17
+ return MsgConst.LogLevel.INFO.value
18
+ else:
19
+ return int(input_level)
14
20
 
15
21
  def get_rank(self):
16
22
  return self.rank
@@ -22,23 +28,26 @@ class BaseLogger:
22
28
  msg = msg.replace(char, '_')
23
29
  return func(self, msg, **kwargs)
24
30
  return func_level
25
-
26
- @filter_special_chars
27
- def info(self, msg, **kwargs):
28
- self._print_log(MsgConst.LEVEL[0], msg, **kwargs)
29
-
31
+
30
32
  @filter_special_chars
31
33
  def error(self, msg):
32
- self._print_log(MsgConst.LEVEL[2], msg)
34
+ if self.level <= MsgConst.LogLevel.ERROR.value:
35
+ self._print_log(MsgConst.LOG_LEVEL[3], msg)
33
36
 
34
37
  @filter_special_chars
35
38
  def warning(self, msg):
36
- self._print_log(MsgConst.LEVEL[1], msg)
39
+ if self.level <= MsgConst.LogLevel.WARNING.value:
40
+ self._print_log(MsgConst.LOG_LEVEL[2], msg)
41
+
42
+ @filter_special_chars
43
+ def info(self, msg):
44
+ if self.level <= MsgConst.LogLevel.INFO.value:
45
+ self._print_log(MsgConst.LOG_LEVEL[1], msg)
37
46
 
38
47
  @filter_special_chars
39
48
  def debug(self, msg):
40
- if MSPROBE_LOG_LEVEL == MsgConst.LEVEL[3]:
41
- self._print_log(MsgConst.LEVEL[3], msg)
49
+ if self.level <= MsgConst.LogLevel.DEBUG.value:
50
+ self._print_log(MsgConst.LOG_LEVEL[0], msg)
42
51
 
43
52
  def on_rank_0(self, func):
44
53
  def func_rank_0(*args, **kwargs):
@@ -73,4 +82,5 @@ class BaseLogger:
73
82
  print(full_msg, end=end)
74
83
  sys.stdout.flush()
75
84
 
85
+
76
86
  logger = BaseLogger()