mindstudio-probe 8.1.1__py3-none-any.whl → 8.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/mindspore/monitor/utils.py CHANGED

@@ -12,16 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import re
-from datetime import datetime
 from mindspore import dtype as mstype, Tensor
 
 from msprobe.mindspore.monitor.features import FUNC_MAP
-from msprobe.core.common.const import MonitorConst
-from msprobe.core.common.utils import is_int
-from msprobe.core.common.log import logger
-from msprobe.core.common.file_utils import check_file_or_directory_path
 
 
 def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None):
@@ -82,248 +75,3 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t
     :return: whether skip or not, bool
     """
     return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times
-
-
-def validate_ops(ops):
-    if not isinstance(ops, list):
-        raise TypeError("ops should be a list")
-    valid_ops = []
-    for op in ops:
-        if op not in MonitorConst.OP_LIST:
-            logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
-            continue
-        valid_ops.append(op)
-    if not valid_ops:
-        default_op = MonitorConst.OP_LIST[0]
-        valid_ops.append(default_op)
-        logger.info(f"There is no valid ops, default op {default_op} is used")
-    # Add default shape and dtype parameters
-    if "shape" not in valid_ops:
-        valid_ops.append("shape")
-    if "dtype" not in valid_ops:
-        valid_ops.append("dtype")
-    return valid_ops
-
-
-def validate_ranks(ranks):
-    if not isinstance(ranks, list):
-        raise TypeError("module_ranks should be a list")
-    for rank in ranks:
-        if not isinstance(rank, int):
-            raise TypeError(f"element in module_ranks should be a int, get {type(rank)}")
-
-
-def validate_targets(targets):
-    if not isinstance(targets, dict):
-        raise TypeError('targets in config.json should be a dict')
-    for module_name, field in targets.items():
-        if not isinstance(module_name, str):
-            raise TypeError('key of targets should be module_name[str] in config.json')
-        if not isinstance(field, dict):
-            raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json')
-
-
-def validate_print_struct(print_struct):
-    if not isinstance(print_struct, bool):
-        raise TypeError("print_struct should be a bool")
-
-
-def validate_ur_distribution(ur_distribution):
-    if not isinstance(ur_distribution, bool):
-        raise TypeError('ur_distribution should be a bool')
-
-
-def validate_xy_distribution(xy_distribution):
-    if not isinstance(xy_distribution, bool):
-        raise TypeError('xy_distribution should be a bool')
-
-
-def validate_wg_distribution(wg_distribution):
-    if not isinstance(wg_distribution, bool):
-        raise TypeError('wg_distribution should be a bool')
-
-
-def validate_mg_distribution(mg_distribution):
-    if not isinstance(mg_distribution, bool):
-        raise TypeError('mg_distribution should be a bool')
-
-
-def validate_param_distribution(param_distribution):
-    if not isinstance(param_distribution, bool):
-        raise TypeError('param_distribution should be a bool')
-
-
-def validate_cc_distribution(cc_distribution):
-    if not isinstance(cc_distribution, dict):
-        raise TypeError('cc_distribution should be a dictionary')
-    expected_keys = {
-        'enable': bool,
-        'cc_codeline': list,
-        'cc_pre_hook': bool,
-        'cc_log_only': bool
-    }
-    for key, value in cc_distribution.items():
-        if key in expected_keys:
-            if not isinstance(value, expected_keys[key]):
-                raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}')
-        else:
-            raise TypeError(f'{key} of cc_distribution is not supported.')
-
-
-def validate_alert(alert):
-    if not isinstance(alert, dict):
-        raise TypeError('alert should be a dictionary')
-    rules = alert.get('rules')
-    if rules and isinstance(rules, list):
-        for rule in rules:
-            rule_name = rule.get("rule_name")
-            if rule_name and rule_name not in MonitorConst.RULE_NAME:
-                raise TypeError(f"{rule_name} is not supported")
-            args = rule.get("args")
-            if args and isinstance(args, dict):
-                threshold = args.get("threshold")
-                if not isinstance(threshold, (float, int)) or threshold < 0:
-                    raise TypeError('threshold must be float and not less than 0')
-    dump = alert.get('dump')
-    if dump and not isinstance(dump, bool):
-        raise TypeError('dump must be bool.')
-
-
-def validate_step_count_per_record(step_count_per_record):
-    if not is_int(step_count_per_record):
-        raise TypeError('step_count_per_record must be int.')
-    if step_count_per_record < 1:
-        raise ValueError("step_count_per_record must greater than 0")
-    if step_count_per_record > 1e6:
-        raise ValueError("step_count_per_record must smaller than 1e6")
-
-
-def validate_start_step(start_step):
-    if not is_int(start_step):
-        raise TypeError('start_step must be int.')
-    if start_step < 0:
-        raise ValueError("start_step must greater than 0")
-    if start_step > 1e8:
-        raise ValueError("start_step must smaller than 1e8")
-
-
-def validate_step_interval(step_interval):
-    if not is_int(step_interval):
-        raise TypeError('step_interval must be int.')
-    if step_interval < 1:
-        raise ValueError("step_interval must greater than 1")
-    if step_interval > 1e8:
-        raise ValueError("step_interval must smaller than 1e8")
-
-
-def validate_collect_times(collect_times):
-    if not is_int(collect_times):
-        raise TypeError('collect_times must be int.')
-    if collect_times < 1:
-        raise ValueError("collect_times must greater than 1")
-
-
-def validate_dynamic_on(dynamic_on):
-    if not isinstance(dynamic_on, bool):
-        raise TypeError('dynamic_on should be a bool')
-
-
-def validate_monitor_mbs_grad(monitor_mbs_grad):
-    if not isinstance(monitor_mbs_grad, bool):
-        logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.')
-        return False
-    return monitor_mbs_grad
-
-
-def validate_config(config):
-    config['ops'] = validate_ops(config.get('ops', []))
-
-    eps = config.get('eps', 1e-8)
-    if not isinstance(eps, float):
-        raise TypeError("eps should be a float")
-
-    ranks = config.get("module_ranks", [])
-    validate_ranks(ranks)
-
-    targets = config.get("targets", {})
-    validate_targets(targets)
-
-    print_struct = config.get('print_struct', False)
-    validate_print_struct(print_struct)
-
-    ur_distribution = config.get('ur_distribution', False)
-    validate_ur_distribution(ur_distribution)
-
-    xy_distribution = config.get('xy_distribution', False)
-    validate_xy_distribution(xy_distribution)
-
-    wg_distribution = config.get('wg_distribution', False)
-    validate_wg_distribution(wg_distribution)
-
-    mg_distribution = config.get('mg_distribution', False)
-    validate_mg_distribution(mg_distribution)
-
-    param_distribution = config.get('param_distribution', False)
-    validate_param_distribution(param_distribution)
-
-    cc_distribution = config.get('cc_distribution', {})
-    validate_cc_distribution(cc_distribution)
-
-    alert = config.get('alert', {})
-    validate_alert(alert)
-
-    step_count_per_record = config.get('step_count_per_record', 1)
-    validate_step_count_per_record(step_count_per_record)
-
-    start_step = config.get('start_step', 0)
-    validate_start_step(start_step)
-
-    step_interval = config.get('step_interval', 1)
-    validate_step_interval(step_interval)
-
-    collect_times = config.get('collect_times', int(1e8))
-    validate_collect_times(collect_times)
-
-    config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False))
-
-    dynamic_on = config.get('dynamic_on', False)
-    validate_dynamic_on(dynamic_on)
-
-    if not targets:
-        if xy_distribution:
-            config["all_xy"] = True
-        config["targets"] = {"": {}}
-        config["is_select"] = False
-    else:
-        config["is_select"] = True
-
-
-def time_str2time_digit(time_str):
-    time_format = '%b%d_%H-%M-%S'
-    try:
-        time_digit = datetime.strptime(time_str, time_format)
-    except Exception as e:
-        raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
-of existing output dirpath, like 'Dec03_21-34-40'.") from e
-    return time_digit
-
-
-def get_target_output_dir(monitor_path, time_start, time_end):
-    check_file_or_directory_path(monitor_path, isdir=True)
-    time_start = time_str2time_digit(time_start) if time_start is not None else time_start
-    time_end = time_str2time_digit(time_end) if time_end is not None else time_end
-    if time_start and time_end and time_start > time_end:
-        raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
-    result = {}
-    for dirname in os.listdir(monitor_path):
-        match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
-        if not match:
-            continue
-        time_tag = match.group(1)
-        rank = match.group(2)
-        target_time = time_str2time_digit(time_tag)
-        start_ok = time_start is None or target_time >= time_start
-        end_ok = time_end is None or target_time <= time_end
-        if start_ok and end_ok:
-            result[rank] = os.path.join(monitor_path, dirname)
-    return result
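Per the file list above, these validators did not vanish: msprobe/core/monitor/utils.py gains 338 lines and msprobe/pytorch/monitor/utils.py sheds 317 in the same release, so the MindSpore and PyTorch monitor validation appears to have been consolidated under msprobe/core. The deleted functions also double as documentation of the monitor config schema. A hedged example of a config dict that validate_config would accept, with every key taken from the code above and every value illustrative ("min"/"max" assume those names appear in MonitorConst.OP_LIST):

    config = {
        "ops": ["min", "max"],       # each entry must be in MonitorConst.OP_LIST
        "eps": 1e-8,                 # must be a float
        "module_ranks": [0, 1],      # list of int
        "targets": {"encoder.layers.0": {"input": "tensor"}},
        "xy_distribution": True,     # all *_distribution switches are bool
        "start_step": 0,             # 0 <= start_step <= 1e8
        "step_interval": 1,          # 1 <= step_interval <= 1e8
        "collect_times": 100,        # >= 1
        "step_count_per_record": 1,  # 1 <= value <= 1e6
    }
    validate_config(config)  # mutates config: "ops" gains "shape"/"dtype",
                             # and "is_select" is derived from "targets"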
msprobe/mindspore/ms_config.py CHANGED

@@ -14,7 +14,6 @@
 # limitations under the License.
 
 from msprobe.core.common.const import Const
-from msprobe.core.common.file_utils import load_json
 from msprobe.core.common.utils import is_int
 from msprobe.core.common_config import BaseConfig, CommonConfig
 from msprobe.core.grad_probe.constant import level_adp
msprobe/nan_analyze/graph.py CHANGED

@@ -16,6 +16,7 @@
 from dataclasses import dataclass
 from msprobe.core.common.const import Const
 from msprobe.core.common.log import logger
+from msprobe.core.common.exceptions import MsprobeException
 from msprobe.nan_analyze.utils import FileCache, RankPath, is_ignore_op, check_item_anomaly, NanAnalyseConst
 
 
@@ -52,6 +53,9 @@ class DataNode:
 
     def find_stack(self, stack_info):
         for item in stack_info.values():
+            if not isinstance(item, list):
+                raise MsprobeException(MsprobeException.UNSUPPORTED_TYPE_ERROR,
+                                       f'The value\'s type in stack.json should be a list, not {type(item)}!')
             if len(item) >= 2 and self.op_name in item[0]:
                 return item[1]
         return {}
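The new isinstance guard pins down the shape find_stack expects from stack.json: every value must be a list whose first element names the ops the stack belongs to and whose second element holds the stack frames. A minimal sketch of that assumed structure (field contents illustrative, not taken from a real dump):

    # Hypothetical stack.json content that find_stack can consume:
    stack_info = {
        "0": [
            "Tensor.add.0.forward",                      # item[0]: op name(s) covered
            ["File 'train.py', line 42, in forward"],    # item[1]: call-stack frames
        ]
    }
    # A non-list value such as {"0": "garbled"} now fails fast with
    # MsprobeException.UNSUPPORTED_TYPE_ERROR instead of crashing on item[0].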
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py CHANGED

@@ -33,7 +33,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator
 from msprobe.pytorch.common import parse_json_info_forward_backward
 from msprobe.pytorch.common.log import logger
 from msprobe.core.common.file_utils import FileChecker, check_file_suffix, check_link, FileOpen, \
-    create_directory, load_json, save_json
+    create_directory, load_json, save_json, read_csv
 from msprobe.core.common.file_utils import remove_path
 from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.core.common.utils import CompareException
@@ -76,9 +76,18 @@ def split_json_file(input_file, num_splits, filter_api):
             }
         }
         split_filename = os.path.join(input_dir, f"temp_part{i}.json")
-        save_json(split_filename, temp_data)
         split_files.append(split_filename)
-
+        try:
+            save_json(split_filename, temp_data)
+        except Exception as e:
+            logger.error(f"An error occurred while saving split file: {e}")
+            for file in split_files:
+                try:
+                    remove_path(file)
+                except FileNotFoundError:
+                    logger.error(f"File not found and could not be deleted: {file}")
+            msg = 'ERROR: Split json file failed, please check the input file and try again.'
+            raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
     return split_files, total_items
 
 
@@ -134,9 +143,9 @@ def run_parallel_ut(config):
 
     def update_progress_bar(progress_bar, result_csv_path):
         while any(process.poll() is None for process in processes):
-
-
-
+            result_file = read_csv(result_csv_path)
+            completed_items = len(result_file)
+            progress_bar.update(completed_items - progress_bar.n)
             time.sleep(1)
 
     for api_info in config.api_files:
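The split_json_file change makes the split an all-or-nothing operation: paths written so far are tracked and deleted if any later write fails, then a CompareException is raised. A generic sketch of the same pattern using only the standard library (write_splits and make_path are illustrative names, not part of msprobe):

    import os

    def write_splits(parts, make_path):
        written = []
        try:
            for i, part in enumerate(parts):
                path = make_path(i)
                with open(path, "w") as f:   # stand-in for save_json
                    f.write(part)
                written.append(path)
        except Exception:
            # Roll back every file that was already written, then re-raise.
            for path in written:
                if os.path.exists(path):
                    os.remove(path)          # stand-in for remove_path
            raise
        return written

Note one refinement in the actual hunk: the path is appended to split_files before save_json runs, so a partially written file is rolled back along with the completed ones.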
msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py CHANGED

@@ -293,7 +293,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     if grad_input_index is not None:
         grad_index = grad_input_index.get('grad_index')
 
-    if need_backward:
+    if need_backward and out is not None:
         if need_to_backward(grad_index, out):
             backward_args = backward_content[api_full_name].get("input")
             func_options = {
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py CHANGED

@@ -111,10 +111,8 @@ def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
 
     try:
         # your_private_key_password
-
-
-        import pwinput
-        passphrase = pwinput.pwinput("Enter your password: ")
+        import pwinput
+        passphrase = pwinput.pwinput("Enter your password: ")
         with FileOpen(key_file, "rb") as f:
             key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
         del passphrase
msprobe/pytorch/common/utils.py CHANGED

@@ -264,10 +264,6 @@
     NPU = 'NPU'
     DISTRIBUTED = 'Distributed'
 
-    HIFLOAT8_TYPE = "torch_npu.HiFloat8Tensor"
-    FLOAT8_E5M2_TYPE = "torch.float8_e5m2"
-    FLOAT8_E4M3FN_TYPE = "torch.float8_e4m3fn"
-
     RAISE_PRECISION = {
         torch.float16: torch.float32,
         torch.bfloat16: torch.float32,
@@ -483,18 +479,6 @@ def is_torch_nn_module(variable):
     return isinstance(variable, torch.nn.Module) and not isinstance(variable, torch.jit.ScriptModule)
 
 
-def is_hifloat8_tensor(tensor):
-    if not is_gpu and hasattr(torch_npu, "HiFloat8Tensor") and isinstance(tensor, torch_npu.HiFloat8Tensor):
-        return True
-    return False
-
-
-def is_float8_tensor(tensor):
-    if str(tensor.dtype) in [Const.FLOAT8_E5M2_TYPE, Const.FLOAT8_E4M3FN_TYPE]:
-        return True
-    return is_hifloat8_tensor(tensor)
-
-
 def register_forward_pre_hook(module, forward_pre_hook):
     if torch_version_above_or_equal_2:
         module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True)
msprobe/pytorch/compare/pt_compare.py CHANGED

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from msprobe.core.common.utils import CompareException
+from msprobe.core.common.log import logger
 from msprobe.core.compare.acc_compare import Comparator, ModeConfig, MappingConfig, setup_comparison
 from msprobe.pytorch.compare.utils import read_pt_data
 
@@ -24,6 +26,9 @@ def read_real_data(npu_dir, npu_data_name, bench_dir, bench_data_name, _) -> tup
 
 
 def compare(input_param, output_path, **kwargs):
+    if not isinstance(input_param, dict):
+        logger.error("input_param should be dict, please check!")
+        raise CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)
     config = setup_comparison(input_param, output_path, **kwargs)
 
     mode_config = ModeConfig(config.stack_mode, config.auto_analyze, config.fuzzy_match,
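With this guard, passing a bare path instead of a parameter dict now fails immediately inside compare rather than somewhere deeper in setup_comparison. A sketch of the failure mode (the dict keys msprobe actually expects are defined elsewhere and are not shown in this hunk):

    from msprobe.pytorch.compare.pt_compare import compare  # module path as in this diff

    compare("npu_dump.json", "./output")
    # -> logs "input_param should be dict, please check!" and raises
    #    CompareException(CompareException.INVALID_OBJECT_TYPE_ERROR)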
msprobe/pytorch/debugger/debugger_config.py CHANGED

@@ -98,6 +98,11 @@ class DebuggerConfig:
 
     def check_model(self, instance, start_model, token_range=None):
         instance.model = start_model if start_model is not None else instance.model
+
+        if token_range and not instance.model:
+            error_info = "The 'model' parameter must be provided when token_range is not None"
+            raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, error_info)
+
         if self.level not in [Const.LEVEL_L0, Const.LEVEL_MIX] and token_range is None:
             return
 
@@ -110,18 +115,20 @@ class DebuggerConfig:
         if is_torch_nn_module(instance.model):
             return
 
-        error_model = None
         if isinstance(instance.model, (list, tuple)):
+            error_model = None
             for model in instance.model:
                 if not is_torch_nn_module(model):
                     error_model = model
                     break
+            if error_model is not None:
+                error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
+                              f"type, currently there is an unsupported {type(error_model)} type.")
+                raise MsprobeException(
+                    MsprobeException.INVALID_PARAM_ERROR, error_info)
         else:
-            error_model = instance.model
-
-        if error_model is not None:
             error_info = (f"The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] "
-                          f"type, currently there is an unsupported {type(error_model)} type.")
+                          f"type, currently there is an unsupported {type(instance.model)} type.")
             raise MsprobeException(
                 MsprobeException.INVALID_PARAM_ERROR, error_info)
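The practical effect: a debugger session can no longer be started with a token_range but no model, and a non-Module model now produces a precise error message. A hedged usage sketch (the import path and config file name are illustrative, not taken from this diff):

    from msprobe.pytorch import PrecisionDebugger  # import path assumed

    debugger = PrecisionDebugger(config_path="./config.json")
    debugger.start(model=None, token_range=[0, 100])
    # -> MsprobeException: The 'model' parameter must be provided when
    #    token_range is not None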
msprobe/pytorch/debugger/precision_debugger.py CHANGED

@@ -17,7 +17,7 @@ from torch.utils.data import dataloader
 
 from msprobe.core.common.const import Const, MsgConst
 from msprobe.core.common.exceptions import MsprobeException
-from msprobe.core.common.utils import check_token_range
+from msprobe.core.common.utils import check_token_range, ThreadSafe
 from msprobe.core.debugger.precision_debugger import BasePrecisionDebugger
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import check_save_param, is_torch_nn_module
@@ -81,6 +81,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         return func_wrapper
 
     @classmethod
+    @ThreadSafe.synchronized
     def start(cls, model=None, token_range=None):
         instance = cls._get_instance()
         if instance is None:
@@ -95,6 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         instance.service.start(instance.model, token_range)
 
     @classmethod
+    @ThreadSafe.synchronized
     def stop(cls):
         instance = cls._get_instance()
         if instance is None:
@@ -105,6 +107,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         instance.service.stop()
 
     @classmethod
+    @ThreadSafe.synchronized
     def step(cls):
         instance = cls._get_instance()
         if instance is None:
@@ -112,6 +115,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         cls._instance.service.step()
 
     @classmethod
+    @ThreadSafe.synchronized
     def monitor(cls, model):
         if not cls._instance:
             raise Exception(MsgConst.NOT_CREATED_INSTANCE)
@@ -120,6 +124,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         cls._instance.gm.monitor(model)
 
     @classmethod
+    @ThreadSafe.synchronized
    def save(cls, variable, name, save_backward=True):
        instance = cls._instance
        if not instance:
@@ -143,6 +148,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
         dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)
 
 
+@ThreadSafe.synchronized
 def module_dump(module, dump_name):
     if not is_torch_nn_module(module):
         raise MsprobeException(
@@ -164,6 +170,7 @@ def module_dump(module, dump_name):
     instance.module_dumper.start_module_dump(module, dump_name)
 
 
+@ThreadSafe.synchronized
 def module_dump_end():
     instance = PrecisionDebugger._instance
     if not instance:
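Every public entry point of the PyTorch PrecisionDebugger is now serialized through ThreadSafe.synchronized. The decorator's implementation lives in msprobe.core.common.utils and is not shown in this diff; a typical shape for such a decorator, sketched here purely as an assumption, is a shared reentrant lock:

    import functools
    import threading

    class ThreadSafe:
        _lock = threading.RLock()  # reentrant, so synchronized calls may nest

        @classmethod
        def synchronized(cls, func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                with cls._lock:
                    return func(*args, **kwargs)
            return wrapper

A reentrant lock (rather than a plain Lock) matters in such a sketch because synchronized entry points like start and module_dump could plausibly call one another.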
msprobe/pytorch/dump/module_dump/hook_wrapper.py CHANGED

@@ -21,13 +21,11 @@ from torch.utils.hooks import BackwardHook
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_float8_tensor
 
 
 def wrap_setup_backward_hook(func):
     def requires_clone(tensor):
-        return isinstance(tensor, torch.Tensor) and not is_float8_tensor(tensor) and \
-            tensor.requires_grad and torch.is_grad_enabled()
+        return isinstance(tensor, torch.Tensor) and tensor.requires_grad and torch.is_grad_enabled()
 
     @recursion_depth_decorator("Dump: wrap_setup_backward_hook.parse_tensor", max_depth=Const.DUMP_MAX_DEPTH)
     def parse_tensor(item, tensor_list):
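Two things happen in this hunk: the float8 special-casing disappears along with the is_float8_tensor helper deleted from msprobe/pytorch/common/utils.py above, and the condition is collapsed onto a single line. The one-line form also sidesteps a classic Python pitfall: a condition split across lines without parentheses or a trailing backslash is a SyntaxError. An illustrative safe multi-line alternative (assumes torch is importable; not part of this diff):

    import torch

    def requires_clone(tensor):
        # Parentheses keep the whole expression inside the return statement,
        # so the condition can be wrapped without a backslash.
        return (isinstance(tensor, torch.Tensor)
                and tensor.requires_grad
                and torch.is_grad_enabled())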