mindstudio-probe 8.1.2__py3-none-any.whl → 8.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
- msprobe/README.md +6 -6
- msprobe/core/common/const.py +98 -41
- msprobe/core/common/db_manager.py +256 -0
- msprobe/core/common/file_utils.py +28 -5
- msprobe/core/common/log.py +7 -0
- msprobe/core/common/megatron_utils.py +59 -0
- msprobe/core/common/parallel_state.py +193 -0
- msprobe/core/common/utils.py +20 -13
- msprobe/core/common_config.py +5 -0
- msprobe/core/compare/acc_compare.py +140 -93
- msprobe/core/compare/check.py +13 -0
- msprobe/core/compare/compare_cli.py +64 -6
- msprobe/core/compare/config.py +10 -8
- msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
- msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
- msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
- msprobe/core/compare/find_first/__init__.py +0 -0
- msprobe/core/compare/find_first/analyzer.py +282 -0
- msprobe/core/compare/find_first/data_processor.py +35 -0
- msprobe/core/compare/find_first/graph.py +188 -0
- msprobe/core/compare/find_first/utils.py +189 -0
- msprobe/core/compare/highlight.py +74 -101
- msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
- msprobe/core/compare/merge_result/merge_result.py +2 -2
- msprobe/core/compare/multiprocessing_compute.py +45 -28
- msprobe/core/compare/npy_compare.py +7 -10
- msprobe/core/compare/utils.py +338 -130
- msprobe/core/config_check/checkers/dataset_checker.py +2 -1
- msprobe/core/config_check/checkers/env_args_checker.py +5 -5
- msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
- msprobe/core/config_check/checkers/pip_checker.py +4 -3
- msprobe/core/config_check/checkers/random_checker.py +3 -3
- msprobe/core/config_check/checkers/weights_checker.py +2 -1
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
- msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
- msprobe/core/config_check/utils/utils.py +10 -0
- msprobe/core/data_dump/api_registry.py +49 -30
- msprobe/core/data_dump/data_collector.py +71 -29
- msprobe/core/data_dump/data_processor/base.py +2 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
- msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
- msprobe/core/data_dump/json_writer.py +81 -7
- msprobe/core/data_dump/scope.py +4 -6
- msprobe/core/hook_manager.py +129 -70
- msprobe/core/monitor/csv2db.py +361 -0
- msprobe/core/monitor/db_utils.py +278 -0
- msprobe/core/monitor/utils.py +35 -1
- msprobe/core/service.py +31 -39
- msprobe/core/single_save/single_comparator.py +16 -3
- msprobe/docs/01.installation.md +51 -19
- msprobe/docs/02.config_introduction.md +16 -20
- msprobe/docs/03.config_examples.md +26 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +6 -2
- msprobe/docs/06.data_dump_MindSpore.md +44 -7
- msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
- msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
- msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
- msprobe/docs/14.data_parse_PyTorch.md +1 -1
- msprobe/docs/19.monitor.md +94 -7
- msprobe/docs/21.visualization_PyTorch.md +71 -101
- msprobe/docs/22.visualization_MindSpore.md +69 -119
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
- msprobe/docs/28.debugger_save_instruction.md +184 -81
- msprobe/docs/29.data_dump_MSAdapter.md +6 -0
- msprobe/docs/31.config_check.md +4 -2
- msprobe/docs/36.calculation_result_change.md +75 -0
- msprobe/docs/FAQ.md +22 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
- msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
- msprobe/mindspore/__init__.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
- msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
- msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
- msprobe/mindspore/cell_processor.py +64 -25
- msprobe/mindspore/common/utils.py +51 -7
- msprobe/mindspore/compare/common_dir_compare.py +45 -37
- msprobe/mindspore/compare/ms_compare.py +10 -2
- msprobe/mindspore/compare/ms_graph_compare.py +47 -52
- msprobe/mindspore/debugger/debugger_config.py +18 -7
- msprobe/mindspore/debugger/precision_debugger.py +16 -12
- msprobe/mindspore/dump/cell_dump_process.py +130 -68
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
- msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
- msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
- msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
- msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
- msprobe/mindspore/exception_dump/__init__.py +0 -0
- msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
- msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
- msprobe/mindspore/monitor/features.py +82 -0
- msprobe/mindspore/monitor/module_hook.py +168 -10
- msprobe/mindspore/monitor/utils.py +27 -1
- msprobe/mindspore/ms_config.py +12 -4
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/mindspore/task_handler_factory.py +3 -1
- msprobe/nan_analyze/graph.py +1 -1
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +1 -21
- msprobe/pytorch/compare/pt_compare.py +10 -2
- msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
- msprobe/pytorch/compare/utils.py +2 -1
- msprobe/pytorch/debugger/debugger_config.py +18 -23
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
- msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
- msprobe/pytorch/free_benchmark/main.py +7 -4
- msprobe/pytorch/hook_module/api_register.py +62 -24
- msprobe/pytorch/hook_module/hook_module.py +9 -29
- msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
- msprobe/pytorch/hook_module/script_wrapper.py +140 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
- msprobe/pytorch/monitor/csv2tb.py +1 -1
- msprobe/pytorch/monitor/features.py +94 -0
- msprobe/pytorch/monitor/module_hook.py +221 -81
- msprobe/pytorch/monitor/module_metric.py +27 -1
- msprobe/pytorch/monitor/optimizer_collect.py +109 -4
- msprobe/pytorch/online_dispatch/dispatch.py +42 -24
- msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
- msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +7 -14
- msprobe/visualization/builder/graph_builder.py +192 -63
- msprobe/visualization/builder/graph_merger.py +986 -0
- msprobe/visualization/builder/msprobe_adapter.py +17 -15
- msprobe/visualization/compare/graph_comparator.py +26 -16
- msprobe/visualization/db_utils.py +252 -0
- msprobe/visualization/graph/base_node.py +2 -22
- msprobe/visualization/graph/distributed_analyzer.py +12 -12
- msprobe/visualization/graph/graph.py +44 -16
- msprobe/visualization/graph_service.py +143 -59
- msprobe/visualization/utils.py +103 -4
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
- /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0
|
@@ -18,6 +18,7 @@ import torch
|
|
|
18
18
|
|
|
19
19
|
from msprobe.pytorch.common.log import logger
|
|
20
20
|
from msprobe.core.monitor.utils import MVResult
|
|
21
|
+
from msprobe.pytorch.monitor.module_metric import get_metrics
|
|
21
22
|
from msprobe.core.common.const import MonitorConst
|
|
22
23
|
|
|
23
24
|
|
|
@@ -26,6 +27,8 @@ class OptimizerMon(object):
|
|
|
26
27
|
self.fp16_to_fp32_param = {}
|
|
27
28
|
self.torch_opt = torch_opt
|
|
28
29
|
self.state = {}
|
|
30
|
+
self.origin_funcs = []
|
|
31
|
+
self.bucket_class = None
|
|
29
32
|
|
|
30
33
|
def narrow_from_flatten(self, param, flatten_state):
|
|
31
34
|
return flatten_state
|
|
@@ -49,11 +52,13 @@ class OptimizerMon(object):
|
|
|
49
52
|
if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
|
|
50
53
|
continue
|
|
51
54
|
grad = param.main_grad if monitor.params_have_main_grad else param.grad
|
|
55
|
+
if grad.__class__.__name__ == 'DTensor':
|
|
56
|
+
grad = grad.to_local()
|
|
52
57
|
element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
|
|
53
58
|
if param.numel() != element_in_cur_partition:
|
|
54
59
|
if first_param:
|
|
55
60
|
grad = grad.flatten()[-element_in_cur_partition:]
|
|
56
|
-
else:
|
|
61
|
+
else: # supposed to be the last one
|
|
57
62
|
grad = grad.flatten()[:element_in_cur_partition]
|
|
58
63
|
first_param = False
|
|
59
64
|
|
|
@@ -120,6 +125,59 @@ class OptimizerMon(object):
|
|
|
120
125
|
monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
|
|
121
126
|
return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
|
|
122
127
|
|
|
128
|
+
def patch_grad_sync(self, monitor):
|
|
129
|
+
def patch_sync(sync_grad_func):
|
|
130
|
+
def wrapper(bucket):
|
|
131
|
+
grad_dict = {}
|
|
132
|
+
# Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
|
|
133
|
+
# When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
|
|
134
|
+
# In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
|
|
135
|
+
bucket_params_id_list = [id(params) for params in bucket.params]
|
|
136
|
+
for param, name in monitor.param2name.items():
|
|
137
|
+
if id(param) not in bucket_params_id_list:
|
|
138
|
+
continue
|
|
139
|
+
grad = param.main_grad if monitor.params_have_main_grad else param.grad
|
|
140
|
+
if grad is None:
|
|
141
|
+
logger.warning(f"grad is None: {name}, maybe something wrong happened.")
|
|
142
|
+
continue
|
|
143
|
+
tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
|
|
144
|
+
if tag is None:
|
|
145
|
+
continue
|
|
146
|
+
grad_dict[tag] = grad
|
|
147
|
+
monitor.register_param_call_id("sync_grad_func", tag)
|
|
148
|
+
get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
|
|
149
|
+
out = sync_grad_func(bucket)
|
|
150
|
+
return out
|
|
151
|
+
|
|
152
|
+
return wrapper
|
|
153
|
+
|
|
154
|
+
try:
|
|
155
|
+
from megatron.core.distributed.param_and_grad_buffer import Bucket
|
|
156
|
+
self.origin_funcs.append(Bucket.start_grad_sync)
|
|
157
|
+
self.bucket_class = Bucket
|
|
158
|
+
Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
|
|
159
|
+
monitor.enable_megatron = True
|
|
160
|
+
logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
|
|
161
|
+
except ImportError:
|
|
162
|
+
monitor.enable_megatron = False
|
|
163
|
+
|
|
164
|
+
try:
|
|
165
|
+
from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
|
|
166
|
+
self.origin_funcs.append(_ParamAndGradBucketGroup.start_grad_sync)
|
|
167
|
+
self.bucket_class = _ParamAndGradBucketGroup
|
|
168
|
+
_ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
|
|
169
|
+
monitor.enable_megatron = True
|
|
170
|
+
logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
|
|
171
|
+
except ImportError:
|
|
172
|
+
monitor.enable_megatron = False | monitor.enable_megatron
|
|
173
|
+
|
|
174
|
+
def restore_grad_sync(self, monitor):
|
|
175
|
+
if not monitor.enable_megatron:
|
|
176
|
+
return
|
|
177
|
+
|
|
178
|
+
self.bucket_class.start_grad_sync = self.origin_funcs[0]
|
|
179
|
+
|
|
180
|
+
|
|
123
181
|
def _get_single_state(self, torch_opt):
|
|
124
182
|
state = {}
|
|
125
183
|
if hasattr(torch_opt, 'param_to_cpu_states_map'):
|
|
@@ -131,7 +189,7 @@ class OptimizerMon(object):
|
|
|
131
189
|
self.state.update(state)
|
|
132
190
|
|
|
133
191
|
|
|
134
|
-
class
|
|
192
|
+
class MegatronMixPrecisionOptimizerMon(OptimizerMon):
|
|
135
193
|
"""
|
|
136
194
|
混合精度优化器监控类。在混合精度训练中监控和管理优化器。
|
|
137
195
|
混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
|
|
@@ -161,7 +219,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
|
|
|
161
219
|
super().map_fp16_to_fp32_param(opt)
|
|
162
220
|
|
|
163
221
|
|
|
164
|
-
class MegatronChainedMixPrecisionOptimizerMon(
|
|
222
|
+
class MegatronChainedMixPrecisionOptimizerMon(MegatronMixPrecisionOptimizerMon):
|
|
165
223
|
def map_fp16_to_fp32_param(self, torch_opt):
|
|
166
224
|
for opt in torch_opt.chained_optimizers:
|
|
167
225
|
super().map_fp16_to_fp32_param(opt)
|
|
@@ -248,6 +306,12 @@ class DeepSpeedZeroOptimizerMon(OptimizerMon):
|
|
|
248
306
|
grad_dict[tag] = grad
|
|
249
307
|
|
|
250
308
|
return grad_dict
|
|
309
|
+
|
|
310
|
+
def patch_grad_sync(self, monitor):
|
|
311
|
+
pass
|
|
312
|
+
|
|
313
|
+
def restore_grad_sync(self, monitor):
|
|
314
|
+
pass
|
|
251
315
|
|
|
252
316
|
|
|
253
317
|
class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
|
|
@@ -291,6 +355,47 @@ class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
|
|
|
291
355
|
break
|
|
292
356
|
|
|
293
357
|
|
|
358
|
+
def patch_grad_sync(self, monitor):
|
|
359
|
+
def patch_sync(reduce_func):
|
|
360
|
+
def wrapper(zero_optimizer, *args, **kwargs):
|
|
361
|
+
grad_dict = {}
|
|
362
|
+
for i, param, _ in zero_optimizer.params_in_ipg_bucket:
|
|
363
|
+
if isinstance(param, int): # for ds >= 0.17.0
|
|
364
|
+
param = zero_optimizer.bit16_groups[i][param]
|
|
365
|
+
name = monitor.param2name[param]
|
|
366
|
+
tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
|
|
367
|
+
grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param)
|
|
368
|
+
monitor.register_param_call_id("sync_grad_func", tag)
|
|
369
|
+
get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
|
|
370
|
+
out = reduce_func(zero_optimizer, *args, **kwargs)
|
|
371
|
+
return out
|
|
372
|
+
|
|
373
|
+
return wrapper
|
|
374
|
+
try:
|
|
375
|
+
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
|
|
376
|
+
self.origin_funcs = [
|
|
377
|
+
DeepSpeedZeroOptimizer.average_tensor,
|
|
378
|
+
DeepSpeedZeroOptimizer.buffered_reduce_fallback
|
|
379
|
+
]
|
|
380
|
+
DeepSpeedZeroOptimizer.average_tensor = patch_sync(DeepSpeedZeroOptimizer.average_tensor)
|
|
381
|
+
DeepSpeedZeroOptimizer.buffered_reduce_fallback = \
|
|
382
|
+
patch_sync(DeepSpeedZeroOptimizer.buffered_reduce_fallback)
|
|
383
|
+
monitor.enable_deepspeed = True
|
|
384
|
+
logger.info('deepspeed enabled')
|
|
385
|
+
except Exception as e:
|
|
386
|
+
monitor.enable_deepspeed = False | monitor.enable_deepspeed
|
|
387
|
+
logger.warning('Seems using deepspeed zero 1 or 2. But patch average tensor failed')
|
|
388
|
+
|
|
389
|
+
def restore_grad_sync(self, monitor):
|
|
390
|
+
if not monitor.enable_deepspeed:
|
|
391
|
+
return
|
|
392
|
+
|
|
393
|
+
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
|
|
394
|
+
DeepSpeedZeroOptimizer.average_tensor = self.origin_funcs[0]
|
|
395
|
+
DeepSpeedZeroOptimizer.buffered_reduce_fallback = self.origin_funcs[1]
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
|
|
294
399
|
class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
|
|
295
400
|
def __init__(self, torch_opt):
|
|
296
401
|
super().__init__(torch_opt)
|
|
@@ -314,7 +419,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
|
|
|
314
419
|
class OptimizerMonFactory:
|
|
315
420
|
_optimizer_mon_map = {
|
|
316
421
|
"FP32Optimizer": OptimizerMon,
|
|
317
|
-
"Float16OptimizerWithFloat16Params":
|
|
422
|
+
"Float16OptimizerWithFloat16Params": MegatronMixPrecisionOptimizerMon,
|
|
318
423
|
"DistributedOptimizer": MegatronDistributedOptimizerMon,
|
|
319
424
|
"SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
|
|
320
425
|
"ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
|
|
@@ -17,7 +17,7 @@ import json
|
|
|
17
17
|
import os
|
|
18
18
|
import time
|
|
19
19
|
import multiprocessing
|
|
20
|
-
from multiprocessing import Pool
|
|
20
|
+
from multiprocessing import Pool, Lock
|
|
21
21
|
|
|
22
22
|
import torch
|
|
23
23
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
|
|
|
39
39
|
from msprobe.pytorch.online_dispatch.compare import Comparator
|
|
40
40
|
from msprobe.core.common.utils import check_str_param, safe_get_value
|
|
41
41
|
|
|
42
|
+
child_global_lock = None
|
|
42
43
|
current_time = time.strftime("%Y%m%d%H%M%S")
|
|
43
44
|
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
|
|
44
45
|
DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
|
|
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
86
87
|
yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
|
|
87
88
|
self.get_ops(yaml_path)
|
|
88
89
|
|
|
89
|
-
self.lock = None
|
|
90
|
+
self.lock = Lock() if process_num > 0 else None
|
|
90
91
|
max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
|
|
91
92
|
if process_num > max_process_num:
|
|
92
93
|
logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
|
|
93
94
|
raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
|
|
94
95
|
f'but got {process_num}!')
|
|
95
96
|
if process_num > 0:
|
|
96
|
-
self.pool = Pool(process_num)
|
|
97
|
+
self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
|
|
97
98
|
if debug:
|
|
98
99
|
logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
|
|
99
100
|
f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
|
|
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
114
115
|
logger.error("Please check train log, An exception may have occurred!")
|
|
115
116
|
return
|
|
116
117
|
check_file_or_directory_path(summary_path, False)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
fp_handle.close()
|
|
118
|
+
with FileOpen(summary_path, "r") as fp_handle:
|
|
119
|
+
while True:
|
|
120
|
+
json_line_data = fp_handle.readline()
|
|
121
|
+
if json_line_data == '\n':
|
|
122
|
+
continue
|
|
123
|
+
if len(json_line_data) == 0:
|
|
124
|
+
break
|
|
125
|
+
msg = json.loads(json_line_data)
|
|
126
|
+
if len(msg) < 2:
|
|
127
|
+
raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
|
|
128
|
+
self.all_summary[msg[0]] = msg[1]
|
|
129
129
|
|
|
130
130
|
if self.debug_flag:
|
|
131
131
|
input_num = 0
|
|
@@ -163,11 +163,16 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
163
163
|
|
|
164
164
|
call_stack = get_callstack()
|
|
165
165
|
self.call_stack_list.append(call_stack)
|
|
166
|
-
|
|
167
|
-
if
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
self.single_api_index_dict
|
|
166
|
+
|
|
167
|
+
self.lock.acquire() if self.process_num > 0 else None
|
|
168
|
+
try:
|
|
169
|
+
self.api_index += 1
|
|
170
|
+
if aten_api not in self.single_api_index_dict:
|
|
171
|
+
self.single_api_index_dict[aten_api] = 1
|
|
172
|
+
else:
|
|
173
|
+
self.single_api_index_dict[aten_api] += 1
|
|
174
|
+
finally:
|
|
175
|
+
self.lock.release() if self.process_num > 0 else None
|
|
171
176
|
|
|
172
177
|
run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
|
|
173
178
|
|
|
@@ -180,7 +185,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
180
185
|
cpu_kwargs = []
|
|
181
186
|
data_to_cpu(args, 0, cpu_args)
|
|
182
187
|
data_to_cpu(kwargs, 0, cpu_kwargs)
|
|
183
|
-
|
|
188
|
+
|
|
184
189
|
cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
|
|
185
190
|
cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
|
|
186
191
|
|
|
@@ -194,7 +199,12 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
194
199
|
try:
|
|
195
200
|
cpu_out = func(*cpu_args, **cpu_kwargs)
|
|
196
201
|
except RuntimeError as e:
|
|
197
|
-
self.
|
|
202
|
+
self.lock.acquire() if self.process_num > 0 else None
|
|
203
|
+
try:
|
|
204
|
+
self.api_index -= 1
|
|
205
|
+
self.single_api_index_dict[aten_api] -= 1
|
|
206
|
+
finally:
|
|
207
|
+
self.lock.release() if self.process_num > 0 else None
|
|
198
208
|
logger.warning(f"RuntimeError: {e}")
|
|
199
209
|
logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
|
|
200
210
|
return npu_out
|
|
@@ -215,7 +225,7 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
215
225
|
run_param.process_flag = True
|
|
216
226
|
if self.check_fun(func, run_param):
|
|
217
227
|
data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
|
|
218
|
-
|
|
228
|
+
child_global_lock)
|
|
219
229
|
self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
|
|
220
230
|
error_callback=error_call)
|
|
221
231
|
else:
|
|
@@ -233,12 +243,20 @@ class PtdbgDispatch(TorchDispatchMode):
|
|
|
233
243
|
return True
|
|
234
244
|
return False
|
|
235
245
|
|
|
246
|
+
@staticmethod
|
|
247
|
+
def _init_child_process(lock):
|
|
248
|
+
global child_global_lock
|
|
249
|
+
child_global_lock = lock
|
|
250
|
+
|
|
236
251
|
def get_dir_name(self, tag):
|
|
237
252
|
# guarantee file uniqueness
|
|
238
253
|
time.sleep(1)
|
|
239
|
-
|
|
254
|
+
# 时间格式:年-月-日-时-分-秒-毫秒(精确到千分之一秒)
|
|
255
|
+
time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3] # 取前3位毫秒
|
|
256
|
+
|
|
240
257
|
if tag is None or not isinstance(tag, str):
|
|
241
258
|
logger.warning('There is not tag or the type of tag is not string.')
|
|
259
|
+
# 目录名格式:msprobe_rank{设备ID}_{毫秒时间戳}
|
|
242
260
|
dir_name = f'msprobe_rank{self.device_id}_{time_now}'
|
|
243
261
|
else:
|
|
244
262
|
dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
|
|
@@ -21,7 +21,7 @@ from datetime import datetime, timezone
|
|
|
21
21
|
import torch
|
|
22
22
|
from msprobe.core.common.const import Const
|
|
23
23
|
from msprobe.core.common.decorator import recursion_depth_decorator
|
|
24
|
-
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json,
|
|
24
|
+
from msprobe.core.common.file_utils import FileOpen, save_npy, save_json, remove_path, check_link
|
|
25
25
|
from msprobe.pytorch.common.log import logger
|
|
26
26
|
|
|
27
27
|
|
msprobe/pytorch/pt_config.py
CHANGED
|
@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
|
|
|
35
35
|
class TensorConfig(BaseConfig):
|
|
36
36
|
def __init__(self, json_config):
|
|
37
37
|
super().__init__(json_config)
|
|
38
|
-
self.online_run_ut = json_config.get("online_run_ut", False)
|
|
39
|
-
self.nfs_path = json_config.get("nfs_path", "")
|
|
40
|
-
self.host = json_config.get("host", "")
|
|
41
|
-
self.port = json_config.get("port", -1)
|
|
42
|
-
self.tls_path = json_config.get("tls_path", "./")
|
|
43
|
-
self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
|
|
44
38
|
self.check_config()
|
|
45
39
|
self._check_summary_mode()
|
|
46
40
|
self._check_file_format()
|
|
47
|
-
|
|
48
|
-
self._check_online_run_ut()
|
|
41
|
+
|
|
49
42
|
|
|
50
43
|
def _check_file_format(self):
|
|
51
44
|
if self.file_format is not None and self.file_format not in ["npy", "bin"]:
|
|
52
45
|
raise Exception("file_format is invalid")
|
|
53
46
|
|
|
54
|
-
def _check_online_run_ut(self):
|
|
55
|
-
if not isinstance(self.online_run_ut, bool):
|
|
56
|
-
raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
|
|
57
|
-
|
|
58
|
-
if not isinstance(self.online_run_ut_recompute, bool):
|
|
59
|
-
raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
|
|
60
|
-
|
|
61
|
-
if self.nfs_path:
|
|
62
|
-
check_file_or_directory_path(self.nfs_path, isdir=True)
|
|
63
|
-
return
|
|
64
|
-
|
|
65
|
-
if self.tls_path:
|
|
66
|
-
check_file_or_directory_path(self.tls_path, isdir=True)
|
|
67
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
|
|
68
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
|
|
69
|
-
check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
|
|
70
|
-
crl_path = os.path.join(self.tls_path, "crl.pem")
|
|
71
|
-
if os.path.exists(crl_path):
|
|
72
|
-
check_file_or_directory_path(crl_path)
|
|
73
|
-
|
|
74
|
-
if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
|
|
75
|
-
raise Exception(f"host: {self.host} is invalid.")
|
|
76
|
-
|
|
77
|
-
if not isinstance(self.port, int) or not (0 < self.port <= 65535):
|
|
78
|
-
raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
|
|
79
|
-
|
|
80
47
|
|
|
81
48
|
class StatisticsConfig(BaseConfig):
|
|
82
49
|
def __init__(self, json_config):
|
|
@@ -251,12 +218,7 @@ class RunUTConfig(BaseConfig):
|
|
|
251
218
|
self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
|
|
252
219
|
self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
|
|
253
220
|
self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
|
|
254
|
-
|
|
255
|
-
self.nfs_path = json_config.get("nfs_path", "")
|
|
256
|
-
self.host = json_config.get("host", "")
|
|
257
|
-
self.port = json_config.get("port", -1)
|
|
258
|
-
self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
|
|
259
|
-
self.tls_path = json_config.get("tls_path", "./")
|
|
221
|
+
|
|
260
222
|
self.check_run_ut_config()
|
|
261
223
|
|
|
262
224
|
@classmethod
|
|
@@ -274,22 +236,11 @@ class RunUTConfig(BaseConfig):
|
|
|
274
236
|
if not os.path.exists(error_data_path):
|
|
275
237
|
raise Exception("error_data_path: %s does not exist" % error_data_path)
|
|
276
238
|
|
|
277
|
-
@classmethod
|
|
278
|
-
def check_nfs_path_config(cls, nfs_path):
|
|
279
|
-
if nfs_path:
|
|
280
|
-
FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
281
|
-
|
|
282
|
-
@classmethod
|
|
283
|
-
def check_tls_path_config(cls, tls_path):
|
|
284
|
-
if tls_path:
|
|
285
|
-
FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
|
|
286
239
|
|
|
287
240
|
def check_run_ut_config(self):
|
|
288
241
|
RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
|
|
289
242
|
RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
|
|
290
243
|
RunUTConfig.check_error_data_path_config(self.error_data_path)
|
|
291
|
-
RunUTConfig.check_nfs_path_config(self.nfs_path)
|
|
292
|
-
RunUTConfig.check_tls_path_config(self.tls_path)
|
|
293
244
|
|
|
294
245
|
|
|
295
246
|
class GradToolConfig(BaseConfig):
|
|
@@ -15,18 +15,14 @@
|
|
|
15
15
|
|
|
16
16
|
from msprobe.core.common.utils import Const
|
|
17
17
|
from msprobe.core.service import BaseService
|
|
18
|
-
from msprobe.pytorch.attl_manager import ATTLManager
|
|
19
18
|
from msprobe.pytorch.common.log import logger
|
|
20
|
-
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
19
|
+
from msprobe.pytorch.common.utils import get_rank_if_initialized
|
|
21
20
|
from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
|
|
22
|
-
from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate
|
|
21
|
+
from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
|
|
23
22
|
from msprobe.pytorch.hook_module.hook_module import HOOKModule
|
|
24
|
-
from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func
|
|
25
23
|
from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
|
|
26
24
|
from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
|
|
27
|
-
|
|
28
|
-
if torch_version_above_or_equal_2:
|
|
29
|
-
from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
|
|
25
|
+
from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
|
|
30
26
|
|
|
31
27
|
|
|
32
28
|
class PytorchService(BaseService):
|
|
@@ -45,27 +41,24 @@ class PytorchService(BaseService):
|
|
|
45
41
|
self.logger = logger
|
|
46
42
|
self.api_register = get_api_register()
|
|
47
43
|
self.module_processor = ModuleProcesser(self.data_collector.scope)
|
|
48
|
-
self.
|
|
49
|
-
self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
|
|
44
|
+
self.hook_manager = PytorchHookManager(self.data_collector, self.config)
|
|
50
45
|
self.api_template = ApiTemplate
|
|
51
46
|
|
|
52
47
|
def _register_hook(self):
|
|
53
|
-
self.attl_manager.attl_init()
|
|
54
48
|
if self._is_mix_level:
|
|
55
49
|
register_optimizer_hook(self.data_collector)
|
|
56
50
|
|
|
57
51
|
def _register_api_hook(self):
|
|
52
|
+
preprocess_func()
|
|
58
53
|
super()._register_api_hook()
|
|
59
|
-
|
|
54
|
+
wrap_script_func()
|
|
55
|
+
redirect_wait()
|
|
60
56
|
|
|
61
57
|
def _register_module_hook(self):
|
|
62
58
|
ModuleProcesser.enable_module_dump = True
|
|
63
59
|
self.module_processor.register_module_hook(self.model, self.build_hook)
|
|
64
60
|
self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
|
|
65
61
|
|
|
66
|
-
def _run_ut_dispatch(self, status):
|
|
67
|
-
if torch_version_above_or_equal_2:
|
|
68
|
-
run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
|
|
69
62
|
|
|
70
63
|
def _reset_status(self):
|
|
71
64
|
super()._reset_status()
|