mindstudio-probe 8.1.1-py3-none-any.whl → 8.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
  3. msprobe/core/common/const.py +3 -0
  4. msprobe/core/common/file_utils.py +45 -5
  5. msprobe/core/common/utils.py +117 -13
  6. msprobe/core/common_config.py +15 -1
  7. msprobe/core/compare/acc_compare.py +21 -9
  8. msprobe/core/compare/compare_cli.py +10 -2
  9. msprobe/core/compare/merge_result/merge_result.py +1 -1
  10. msprobe/core/compare/utils.py +8 -2
  11. msprobe/core/config_check/checkers/base_checker.py +2 -0
  12. msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
  13. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
  14. msprobe/core/config_check/config_check_cli.py +1 -1
  15. msprobe/core/config_check/config_checker.py +1 -2
  16. msprobe/core/data_dump/data_collector.py +4 -1
  17. msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
  19. msprobe/core/debugger/precision_debugger.py +13 -8
  20. msprobe/core/hook_manager.py +112 -82
  21. msprobe/core/monitor/utils.py +338 -0
  22. msprobe/core/service.py +2 -1
  23. msprobe/core/single_save/single_comparator.py +5 -3
  24. msprobe/docs/01.installation.md +1 -0
  25. msprobe/docs/05.data_dump_PyTorch.md +4 -4
  26. msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
  27. msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
  29. msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
  30. msprobe/docs/12.overflow_check_PyTorch.md +3 -2
  31. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  32. msprobe/docs/14.data_parse_PyTorch.md +35 -32
  33. msprobe/docs/21.visualization_PyTorch.md +9 -8
  34. msprobe/docs/22.visualization_MindSpore.md +1 -0
  35. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  36. msprobe/docs/24.code_mapping_Mindspore.md +6 -5
  37. msprobe/docs/31.config_check.md +15 -5
  38. msprobe/docs/33.generate_operator_MindSpore.md +2 -2
  39. msprobe/docs/34.RL_collect.md +18 -9
  40. msprobe/docs/35.nan_analyze.md +4 -3
  41. msprobe/docs/FAQ.md +3 -0
  42. msprobe/docs/img/ms_layer.png +0 -0
  43. msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
  44. msprobe/mindspore/cell_processor.py +35 -14
  45. msprobe/mindspore/code_mapping/bind.py +23 -4
  46. msprobe/mindspore/code_mapping/graph_parser.py +6 -4
  47. msprobe/mindspore/common/utils.py +3 -0
  48. msprobe/mindspore/compare/common_dir_compare.py +32 -12
  49. msprobe/mindspore/compare/ms_graph_compare.py +7 -2
  50. msprobe/mindspore/compare/utils.py +9 -1
  51. msprobe/mindspore/debugger/debugger_config.py +13 -11
  52. msprobe/mindspore/debugger/precision_debugger.py +67 -45
  53. msprobe/mindspore/dump/dump_tool_factory.py +2 -0
  54. msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
  55. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
  56. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
  57. msprobe/mindspore/dump/jit_dump.py +6 -3
  58. msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
  59. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
  60. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
  61. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
  62. msprobe/mindspore/mindspore_service.py +2 -2
  63. msprobe/mindspore/monitor/common_func.py +1 -1
  64. msprobe/mindspore/monitor/module_hook.py +3 -3
  65. msprobe/mindspore/monitor/utils.py +0 -252
  66. msprobe/mindspore/ms_config.py +0 -1
  67. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  68. msprobe/nan_analyze/graph.py +4 -0
  69. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
  70. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
  71. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
  72. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
  73. msprobe/pytorch/common/utils.py +0 -16
  74. msprobe/pytorch/compare/pt_compare.py +5 -0
  75. msprobe/pytorch/debugger/debugger_config.py +12 -5
  76. msprobe/pytorch/debugger/precision_debugger.py +8 -1
  77. msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
  78. msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
  79. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
  80. msprobe/pytorch/hook_module/hook_module.py +9 -9
  81. msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
  82. msprobe/pytorch/monitor/csv2tb.py +3 -10
  83. msprobe/pytorch/monitor/features.py +5 -0
  84. msprobe/pytorch/monitor/module_hook.py +6 -7
  85. msprobe/pytorch/monitor/module_metric.py +0 -3
  86. msprobe/pytorch/monitor/optimizer_collect.py +1 -1
  87. msprobe/pytorch/monitor/utils.py +1 -317
  88. msprobe/pytorch/online_dispatch/dispatch.py +1 -1
  89. msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
  90. msprobe/pytorch/parse_tool/lib/utils.py +2 -4
  91. msprobe/visualization/graph_service.py +1 -1
  92. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
  93. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
  94. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
  95. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/utils.py
@@ -12,16 +12,9 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import os
- import re
- from datetime import datetime
  from mindspore import dtype as mstype, Tensor

  from msprobe.mindspore.monitor.features import FUNC_MAP
- from msprobe.core.common.const import MonitorConst
- from msprobe.core.common.utils import is_int
- from msprobe.core.common.log import logger
- from msprobe.core.common.file_utils import check_file_or_directory_path


  def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
@@ -82,248 +75,3 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t
      :return: whether skip or not, bool
      """
      return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
-
-
- def validate_ops(ops):
-     if not isinstance(ops, list):
-         raise TypeError("ops should be a list")
-     valid_ops = []
-     for op in ops:
-         if op not in MonitorConst.OP_LIST:
-             logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
-             continue
-         valid_ops.append(op)
-     if not valid_ops:
-         default_op = MonitorConst.OP_LIST[0]
-         valid_ops.append(default_op)
-         logger.info(f"There is no valid ops, default op {default_op} is used")
-     # Add the default "shape" and "dtype" parameters
-     if "shape" not in valid_ops:
-         valid_ops.append("shape")
-     if "dtype" not in valid_ops:
-         valid_ops.append("dtype")
-     return valid_ops
-
-
- def validate_ranks(ranks):
-     if not isinstance(ranks, list):
-         raise TypeError("module_ranks should be a list")
-     for rank in ranks:
-         if not isinstance(rank, int):
-             raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
-
-
- def validate_targets(targets):
-     if not isinstance(targets, dict):
-         raise TypeError('targets in config.json should be a dict')
-     for module_name, field in targets.items():
-         if not isinstance(module_name, str):
-             raise TypeError('key of targets should be module_name[str] in config.json')
-         if not isinstance(field, dict):
-             raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
-
-
- def validate_print_struct(print_struct):
-     if not isinstance(print_struct, bool):
-         raise TypeError("print_struct should be a bool")
-
-
- def validate_ur_distribution(ur_distribution):
-     if not isinstance(ur_distribution, bool):
-         raise TypeError('ur_distribution should be a bool')
-
-
- def validate_xy_distribution(xy_distribution):
-     if not isinstance(xy_distribution, bool):
-         raise TypeError('xy_distribution should be a bool')
-
-
- def validate_wg_distribution(wg_distribution):
-     if not isinstance(wg_distribution, bool):
-         raise TypeError('wg_distribution should be a bool')
-
-
- def validate_mg_distribution(mg_distribution):
-     if not isinstance(mg_distribution, bool):
-         raise TypeError('mg_distribution should be a bool')
-
-
- def validate_param_distribution(param_distribution):
-     if not isinstance(param_distribution, bool):
-         raise TypeError('param_distribution should be a bool')
-
-
- def validate_cc_distribution(cc_distribution):
-     if not isinstance(cc_distribution, dict):
-         raise TypeError('cc_distribution should be a dictionary')
-     expected_keys = {
-         'enable': bool,
-         'cc_codeline': list,
-         'cc_pre_hook': bool,
-         'cc_log_only': bool
-     }
-     for key, value in cc_distribution.items():
-         if key in expected_keys:
-             if not isinstance(value, expected_keys[key]):
-                 raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}')
-         else:
-             raise TypeError(f'{key} of cc_distribution is not supported.')
-
-
- def validate_alert(alert):
-     if not isinstance(alert, dict):
-         raise TypeError('alert should be a dictionary')
-     rules = alert.get('rules')
-     if rules and isinstance(rules, list):
-         for rule in rules:
-             rule_name = rule.get("rule_name")
-             if rule_name and rule_name not in MonitorConst.RULE_NAME:
-                 raise TypeError(f"{rule_name} is not supported")
-             args = rule.get("args")
-             if args and isinstance(args, dict):
-                 threshold = args.get("threshold")
-                 if not isinstance(threshold, (float, int)) or threshold < 0:
-                     raise TypeError('threshold must be float and not less than 0')
-     dump = alert.get('dump')
-     if dump and not isinstance(dump, bool):
-         raise TypeError('dump must be bool.')
-
-
- def validate_step_count_per_record(step_count_per_record):
-     if not is_int(step_count_per_record):
-         raise TypeError('step_count_per_record must be int.')
-     if step_count_per_record < 1:
-         raise ValueError("step_count_per_record must greater than 0")
-     if step_count_per_record > 1e6:
-         raise ValueError("step_count_per_record must smaller than 1e6")
-
-
- def validate_start_step(start_step):
-     if not is_int(start_step):
-         raise TypeError('start_step must be int.')
-     if start_step < 0:
-         raise ValueError("start_step must greater than 0")
-     if start_step > 1e8:
-         raise ValueError("start_step must smaller than 1e8")
-
-
- def validate_step_interval(step_interval):
-     if not is_int(step_interval):
-         raise TypeError('step_interval must be int.')
-     if step_interval < 1:
-         raise ValueError("step_interval must greater than 1")
-     if step_interval > 1e8:
-         raise ValueError("step_interval must smaller than 1e8")
-
-
- def validate_collect_times(collect_times):
-     if not is_int(collect_times):
-         raise TypeError('collect_times must be int.')
-     if collect_times < 1:
-         raise ValueError("collect_times must greater than 1")
-
-
- def validate_dynamic_on(dynamic_on):
-     if not isinstance(dynamic_on, bool):
-         raise TypeError('dynamic_on should be a bool')
-
-
- def validate_monitor_mbs_grad(monitor_mbs_grad):
-     if not isinstance(monitor_mbs_grad, bool):
-         logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
-         return False
-     return monitor_mbs_grad
-
-
- def validate_config(config):
-     config['ops'] = validate_ops(config.get('ops', []))
-
-     eps = config.get('eps', 1e-8)
-     if not isinstance(eps, float):
-         raise TypeError("eps should be a float")
-
-     ranks = config.get("module_ranks", [])
-     validate_ranks(ranks)
-
-     targets = config.get("targets", {})
-     validate_targets(targets)
-
-     print_struct = config.get('print_struct', False)
-     validate_print_struct(print_struct)
-
-     ur_distribution = config.get('ur_distribution', False)
-     validate_ur_distribution(ur_distribution)
-
-     xy_distribution = config.get('xy_distribution', False)
-     validate_xy_distribution(xy_distribution)
-
-     wg_distribution = config.get('wg_distribution', False)
-     validate_wg_distribution(wg_distribution)
-
-     mg_distribution = config.get('mg_distribution', False)
-     validate_mg_distribution(mg_distribution)
-
-     param_distribution = config.get('param_distribution', False)
-     validate_param_distribution(param_distribution)
-
-     cc_distribution = config.get('cc_distribution', {})
-     validate_cc_distribution(cc_distribution)
-
-     alert = config.get('alert', {})
-     validate_alert(alert)
-
-     step_count_per_record = config.get('step_count_per_record', 1)
-     validate_step_count_per_record(step_count_per_record)
-
-     start_step = config.get('start_step', 0)
-     validate_start_step(start_step)
-
-     step_interval = config.get('step_interval', 1)
-     validate_step_interval(step_interval)
-
-     collect_times = config.get('collect_times', int(1e8))
-     validate_collect_times(collect_times)
-
-     config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
-
-     dynamic_on = config.get('dynamic_on', False)
-     validate_dynamic_on(dynamic_on)
-
-     if not targets:
-         if xy_distribution:
-             config["all_xy"] = True
-         config["targets"] = {"": {}}
-         config["is_select"] = False
-     else:
-         config["is_select"] = True
-
-
- def time_str2time_digit(time_str):
-     time_format = '%b%d_%H-%M-%S'
-     try:
-         time_digit = datetime.strptime(time_str, time_format)
-     except Exception as e:
-         raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
- of existing output dirpath, like 'Dec03_21-34-40'.") from e
-     return time_digit
-
-
- def get_target_output_dir(monitor_path, time_start, time_end):
-     check_file_or_directory_path(monitor_path, isdir=True)
-     time_start = time_str2time_digit(time_start) if time_start is not None else time_start
-     time_end = time_str2time_digit(time_end) if time_end is not None else time_end
-     if time_start and time_end and time_start > time_end:
-         raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
-     result = {}
-     for dirname in os.listdir(monitor_path):
-         match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
-         if not match:
-             continue
-         time_tag = match.group(1)
-         rank = match.group(2)
-         target_time = time_str2time_digit(time_tag)
-         start_ok = time_start is None or target_time >= time_start
-         end_ok = time_end is None or target_time <= time_end
-         if start_ok and end_ok:
-             result[rank] = os.path.join(monitor_path, dirname)
-     return result
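Note: together with the import removals above, this deletes the entire config-validation and output-directory toolkit from the MindSpore-specific module. Per the file list, msprobe/core/monitor/utils.py gains 338 lines in this release while msprobe/pytorch/monitor/utils.py loses 317, so these helpers appear to be consolidated into the shared core module rather than dropped. As a minimal standalone sketch of the default-filling behavior that validate_ops and the tail of validate_config implement (OP_LIST is a hypothetical stand-in for MonitorConst.OP_LIST):

    # Hypothetical stand-in for MonitorConst.OP_LIST.
    OP_LIST = ["min", "max", "norm", "mean"]

    def normalize_ops(ops):
        # Mirrors validate_ops: drop unknown ops, fall back to the first
        # default, and always append the implicit "shape"/"dtype" collectors.
        valid = [op for op in ops if op in OP_LIST] or [OP_LIST[0]]
        for extra in ("shape", "dtype"):
            if extra not in valid:
                valid.append(extra)
        return valid

    config = {"ops": ["min", "bogus"], "xy_distribution": True}
    config["ops"] = normalize_ops(config["ops"])
    # Mirrors the tail of validate_config: no explicit targets means
    # "monitor everything" rather than selective collection.
    if not config.get("targets"):
        if config.get("xy_distribution"):
            config["all_xy"] = True
        config["targets"] = {"": {}}
        config["is_select"] = False
    else:
        config["is_select"] = True
    print(config["ops"], config["is_select"])  # ['min', 'shape', 'dtype'] False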
msprobe/mindspore/ms_config.py
@@ -14,7 +14,6 @@
  # limitations under the License.

  from msprobe.core.common.const import Const
- from msprobe.core.common.file_utils import load_json
  from msprobe.core.common.utils import is_int
  from msprobe.core.common_config import BaseConfig, CommonConfig
  from msprobe.core.grad_probe.constant import level_adp
msprobe/mindspore/overflow_check/overflow_check_tool_factory.py
@@ -48,4 +48,4 @@ class OverflowCheckToolFactory:
              logger.error(f"Overflow check is not supported in {config.execution_mode} mode "
                           f"when level is {config.level}.")
              raise ValueError
-         return tool(config)
+         return (tool(config),)
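Note: the behavioral change is the return type only: the factory now hands back a one-element tuple rather than a bare tool instance, presumably so call sites can iterate over the result uniformly however many tools a mode produces. A hedged illustration of the call-site contract:

    def apply_tools(tools):
        # Written against the new contract: the factory result is always iterable.
        for tool in tools:
            print(f"applying {type(tool).__name__}")

    apply_tools((object(),))           # one tool, wrapped in a tuple: fine
    apply_tools((object(), object()))  # several tools: same code path
    # With the old bare return, `for tool in tools` would have iterated over
    # (or crashed on) the single tool object itself.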
msprobe/nan_analyze/graph.py
@@ -16,6 +16,7 @@
  from dataclasses import dataclass
  from msprobe.core.common.const import Const
  from msprobe.core.common.log import logger
+ from msprobe.core.common.exceptions import MsprobeException
  from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst


@@ -52,6 +53,9 @@ class DataNode:

      def find_stack(self, stack_info):
          for item in stack_info.values():
+             if not isinstance(item, list):
+                 raise MsprobeException(MsprobeException.UNSUPPORTED_TYPE_ERROR,
+                                        f'The value\'s type in stack.json should be a list, not {type(item)}!')
              if len(item) >= 2 and self.op_name in item[0]:
                  return item[1]
          return {}
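Note: find_stack assumes every value in stack.json is a two-element list, op names first and stack payload second; the new isinstance check turns a malformed file into an explicit MsprobeException instead of an opaque TypeError at len(item). A hypothetical well-formed entry for illustration:

    # Hypothetical stack.json content in the shape find_stack expects.
    stack_info = {
        "0": [
            ["Functional.add.0.forward"],             # item[0]: op names
            {"File model.py, line 42": "x = a + b"},  # item[1]: returned stack
        ],
    }
    # A value such as {"0": "not-a-list"} now raises
    # MsprobeException(UNSUPPORTED_TYPE_ERROR) with a readable message.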
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -33,7 +33,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
  from msprobe.pytorch.common import parse_json_info_forward_backward
  from msprobe.pytorch.common.log import logger
  from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
-     create_directory, load_json, save_json
+     create_directory, load_json, save_json, read_csv
  from msprobe.core.common.file_utils import remove_path
  from msprobe.core.common.const import FileCheckConst, Const
  from msprobe.core.common.utils import CompareException
@@ -76,9 +76,18 @@ def split_json_file(input_file, num_splits, filter_api):
              }
          }
          split_filename = os.path.join(input_dir, f"temp_part{i}.json")
-         save_json(split_filename, temp_data)
          split_files.append(split_filename)
-
+         try:
+             save_json(split_filename, temp_data)
+         except Exception as e:
+             logger.error(f"An error occurred while saving split file: {e}")
+             for file in split_files:
+                 try:
+                     remove_path(file)
+                 except FileNotFoundError:
+                     logger.error(f"File not found and could not be deleted: {file}")
+             msg = 'ERROR: Split json file failed, please check the input file and try again.'
+             raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
      return split_files, total_items


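Note: the rewritten loop body makes the split an all-or-nothing operation: if any part fails to save, every part written so far is best-effort deleted before CompareException is raised, so a rerun never consumes stale temp_part*.json files. The same pattern in a self-contained sketch:

    import os

    def write_all_or_nothing(paths, payloads):
        """Write every file, or delete the partial set and re-raise."""
        written = []
        try:
            for path, payload in zip(paths, payloads):
                with open(path, "w") as f:
                    f.write(payload)
                written.append(path)
        except OSError:
            for path in written:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass  # already gone; nothing to clean up
            raise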
@@ -134,9 +143,9 @@ def run_parallel_ut(config):

      def update_progress_bar(progress_bar, result_csv_path):
          while any(process.poll() is None for process in processes):
-             with FileOpen(result_csv_path, 'r') as result_file:
-                 completed_items = len(result_file.readlines()) - 1
-                 progress_bar.update(completed_items - progress_bar.n)
+             result_file = read_csv(result_csv_path)
+             completed_items = len(result_file)
+             progress_bar.update(completed_items - progress_bar.n)
              time.sleep(1)

      for api_info in config.api_files:
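Note: swapping raw readlines() for msprobe's read_csv helper is not just about reusing its file-safety checks: a CSV record may legally contain newlines inside a quoted field, in which case counting physical lines overstates progress. A standalone demonstration with the stdlib csv module:

    import csv
    import io

    data = 'api,msg\nfoo,"line one\nline two"\n'
    print(len(data.splitlines()) - 1)                    # 2: physical lines minus header (overcount)
    print(len(list(csv.reader(io.StringIO(data)))) - 1)  # 1: actual records minus header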
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -293,7 +293,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
      if grad_input_index is not None:
          grad_index = grad_input_index.get('grad_index')

-     if need_backward:
+     if need_backward and out is not None:
          if need_to_backward(grad_index, out):
              backward_args = backward_content[api_full_name].get("input")
              func_options = {
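Note: out holds the replayed forward result here; when the forward produced nothing, need_to_backward(grad_index, out) would be handed None, so the added out is not None condition skips the backward replay for that API instead of failing inside it.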
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py
@@ -43,9 +43,9 @@ CommonCompareConfig = namedtuple('CommonCompareConfig', ['compare', 'handle_func


  def get_gpu_device():
+     is_gpu = False
      try:
          import torch_npu
-         is_gpu = False
      except ImportError:
          is_gpu = True
      return is_gpu
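Note: hoisting is_gpu = False above the try block is a defensive initialization: the name is now bound before any branching, so the function stays correct even if the except clause is ever widened beyond ImportError or more statements land between the import and the assignment.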
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py
@@ -111,10 +111,8 @@ def load_ssl_pem(key_file, cert_file, ca_file, crl_file):

      try:
          # your_private_key_password
-         passphrase = ""
-         if not passphrase:
-             import pwinput
-             passphrase = pwinput.pwinput("Enter your password: ")
+         import pwinput
+         passphrase = pwinput.pwinput("Enter your password: ")
          with FileOpen(key_file, "rb") as f:
              key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
          del passphrase
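Note: the deleted lines shipped a passphrase = "" placeholder (with a comment inviting users to hard-code their private-key password) and only fell back to prompting when it was left empty. The new code prompts unconditionally through pwinput, which reads masked input, so the passphrase never has a reason to appear in source; del passphrase afterwards still keeps its lifetime in memory short.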
msprobe/pytorch/common/utils.py
@@ -264,10 +264,6 @@ class Const:
      NPU = 'NPU'
      DISTRIBUTED = 'Distributed'

-     HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
-     FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
-     FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
-
      RAISE_PRECISION = {
          torch.float16: torch.float32,
          torch.bfloat16: torch.float32,
@@ -483,18 +479,6 @@ def is_torch_nn_module(variable):
      return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)


- def is_hifloat8_tensor(tensor):
-     if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
-         return True
-     return False
-
-
- def is_float8_tensor(tensor):
-     if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
-         return True
-     return is_hifloat8_tensor(tensor)
-
-
  def register_forward_pre_hook(module, forward_pre_hook):
      if torch_version_above_or_equal_2:
          module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
msprobe/pytorch/compare/pt_compare.py
@@ -13,6 +13,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from msprobe.core.common.utils import CompareException
+ from msprobe.core.common.log import logger
  from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
  from msprobe.pytorch.compare.utils import read_pt_data


@@ -24,6 +26,9 @@ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tup


  def compare(input_param, output_path, **kwargs):
+     if not isinstance(input_param, dict):
+         logger.error("input_param should be dict, please check!")
+         raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
      config = setup_comparison(input_param, output_path, **kwargs)

      mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
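Note: compare() is a public entry point, so the isinstance gate converts a wrong argument type into an immediate, well-labeled CompareException before setup_comparison touches it. Illustrative call shapes only; the key names below are hypothetical, not the tool's exact schema:

    # compare({"npu_json_path": "npu.json", "bench_json_path": "bench.json"}, "./out")
    #     -> dict input, proceeds into setup_comparison
    # compare("input.json", "./out")
    #     -> str input, now fails fast with INVALID_OBJECT_TYPE_ERROR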
msprobe/pytorch/debugger/debugger_config.py
@@ -98,6 +98,11 @@ class DebuggerConfig:

      def check_model(self, instance, start_model, token_range=None):
          instance.model = start_model if start_model is not None else instance.model
+
+         if token_range and not instance.model:
+             error_info = "The 'model' parameter must be provided when token_range is not None"
+             raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
+
          if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
              return

@@ -110,18 +115,20 @@ class DebuggerConfig:
          if is_torch_nn_module(instance.model):
              return

-         error_model = None
          if isinstance(instance.model, (list, tuple)):
+             error_model = None
              for model in instance.model:
                  if not is_torch_nn_module(model):
                      error_model = model
                      break
+             if error_model is not None:
+                 error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
+                               f"type, currently there is an unsupported {type(error_model)} type.")
+                 raise MsprobeException(
+                     MsprobeException.INVALID_PARAM_ERROR, error_info)
          else:
-             error_model = instance.model
-
-         if error_model is not None:
              error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
-                           f"type, currently there is an unsupported {type(error_model)} type.")
+                           f"type, currently there is an unsupported {type(instance.model)} type.")
              raise MsprobeException(
                  MsprobeException.INVALID_PARAM_ERROR, error_info)

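Note: the reworked check_model fails closer to the cause: the list/tuple branch raises as soon as it finds a non-module element and names that element's type, while the scalar branch reports type(instance.model) directly instead of funnelling both cases through a shared error_model variable. Combined with the token_range guard added above, hypothetical misuse now stops before any hooks are installed:

    # debugger.start(token_range=[0, 32])   -> model missing: INVALID_PARAM_ERROR
    # debugger.start(model=[net, "oops"])   -> unsupported <class 'str'> in the list
    # debugger.start(model="oops")          -> unsupported <class 'str'>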
msprobe/pytorch/debugger/precision_debugger.py
@@ -17,7 +17,7 @@ from torch.utils.data import dataloader

  from msprobe.core.common.const import Const, MsgConst
  from msprobe.core.common.exceptions import MsprobeException
- from msprobe.core.common.utils import check_token_range
+ from msprobe.core.common.utils import check_token_range, ThreadSafe
  from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
  from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
@@ -81,6 +81,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          return func_wrapper

      @classmethod
+     @ThreadSafe.synchronized
      def start(cls, model=None, token_range=None):
          instance = cls._get_instance()
          if instance is None:
@@ -95,6 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          instance.service.start(instance.model, token_range)

      @classmethod
+     @ThreadSafe.synchronized
      def stop(cls):
          instance = cls._get_instance()
          if instance is None:
@@ -105,6 +107,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          instance.service.stop()

      @classmethod
+     @ThreadSafe.synchronized
      def step(cls):
          instance = cls._get_instance()
          if instance is None:
@@ -112,6 +115,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          cls._instance.service.step()

      @classmethod
+     @ThreadSafe.synchronized
      def monitor(cls, model):
          if not cls._instance:
              raise Exception(MsgConst.NOT_CREATED_INSTANCE)
@@ -120,6 +124,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          cls._instance.gm.monitor(model)

      @classmethod
+     @ThreadSafe.synchronized
      def save(cls, variable, name, save_backward=True):
          instance = cls._instance
          if not instance:
@@ -143,6 +148,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
          dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)


+ @ThreadSafe.synchronized
  def module_dump(module, dump_name):
      if not is_torch_nn_module(module):
          raise MsprobeException(
@@ -164,6 +170,7 @@ def module_dump(module, dump_name):
      instance.module_dumper.start_module_dump(module, dump_name)


+ @ThreadSafe.synchronized
  def module_dump_end():
      instance = PrecisionDebugger._instance
      if not instance:
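Note: ThreadSafe.synchronized itself is not shown in this diff (it lives in msprobe/core/common/utils.py, which grows by 117 lines in this release). A minimal sketch of what such a decorator conventionally looks like, assuming one process-wide reentrant lock rather than the project's actual implementation:

    import functools
    import threading

    class ThreadSafe:
        _lock = threading.RLock()  # reentrant, so synchronized calls may nest

        @classmethod
        def synchronized(cls, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with cls._lock:  # serialize every decorated entry point
                    return func(*args, **kwargs)
            return wrapper

Stacked beneath @classmethod as in the hunks above, the lock wraps the plain function first and classmethod binding is applied to the already-synchronized wrapper, which is why that decorator order works.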
msprobe/pytorch/dump/module_dump/hook_wrapper.py
@@ -21,13 +21,11 @@ from torch.utils.hooks import BackwardHook
  from msprobe.core.common.const import Const
  from msprobe.core.common.decorator import recursion_depth_decorator
  from msprobe.pytorch.common.log import logger
- from msprobe.pytorch.common.utils import is_float8_tensor


  def wrap_setup_backward_hook(func):
      def requires_clone(tensor):
-         return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
-             tensor.requires_grad and torch.is_grad_enabled()
+         return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()

      @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
      def parse_tensor(item, tensor_list):
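Note: this is the consumer side of the float8 removals in msprobe/pytorch/common/utils.py above: with is_float8_tensor gone, requires_clone no longer exempts float8/HiFloat8 tensors, so any tensor that requires grad is cloned for the backward hook setup.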