mindstudio-probe 8.1.2-py3-none-any.whl → 8.2.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/METADATA +2 -2
  2. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/RECORD +172 -147
  3. msprobe/README.md +6 -6
  4. msprobe/core/common/const.py +98 -41
  5. msprobe/core/common/db_manager.py +256 -0
  6. msprobe/core/common/file_utils.py +28 -5
  7. msprobe/core/common/log.py +7 -0
  8. msprobe/core/common/megatron_utils.py +59 -0
  9. msprobe/core/common/parallel_state.py +193 -0
  10. msprobe/core/common/utils.py +20 -13
  11. msprobe/core/common_config.py +5 -0
  12. msprobe/core/compare/acc_compare.py +140 -93
  13. msprobe/core/compare/check.py +13 -0
  14. msprobe/core/compare/compare_cli.py +64 -6
  15. msprobe/core/compare/config.py +10 -8
  16. msprobe/core/compare/diff_analyze/diff_analyze_threshold.yaml +14 -0
  17. msprobe/core/compare/diff_analyze/first_diff_analyze.py +135 -0
  18. msprobe/core/compare/diff_analyze/ignore_op_list.yaml +3 -0
  19. msprobe/core/compare/find_first/__init__.py +0 -0
  20. msprobe/core/compare/find_first/analyzer.py +282 -0
  21. msprobe/core/compare/find_first/data_processor.py +35 -0
  22. msprobe/core/compare/find_first/graph.py +188 -0
  23. msprobe/core/compare/find_first/utils.py +189 -0
  24. msprobe/core/compare/highlight.py +74 -101
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +14 -9
  26. msprobe/core/compare/merge_result/merge_result.py +2 -2
  27. msprobe/core/compare/multiprocessing_compute.py +45 -28
  28. msprobe/core/compare/npy_compare.py +7 -10
  29. msprobe/core/compare/utils.py +338 -130
  30. msprobe/core/config_check/checkers/dataset_checker.py +2 -1
  31. msprobe/core/config_check/checkers/env_args_checker.py +5 -5
  32. msprobe/core/config_check/checkers/hyperparameter_checker.py +30 -10
  33. msprobe/core/config_check/checkers/pip_checker.py +4 -3
  34. msprobe/core/config_check/checkers/random_checker.py +3 -3
  35. msprobe/core/config_check/checkers/weights_checker.py +2 -1
  36. msprobe/core/config_check/ckpt_compare/megatron_loader.py +2 -0
  37. msprobe/core/config_check/resource/hyperparameter.yaml +11 -1
  38. msprobe/core/config_check/utils/hyperparameter_parser.py +7 -3
  39. msprobe/core/config_check/utils/utils.py +10 -0
  40. msprobe/core/data_dump/api_registry.py +49 -30
  41. msprobe/core/data_dump/data_collector.py +71 -29
  42. msprobe/core/data_dump/data_processor/base.py +2 -0
  43. msprobe/core/data_dump/data_processor/mindspore_processor.py +47 -53
  44. msprobe/core/data_dump/data_processor/pytorch_processor.py +227 -93
  45. msprobe/core/data_dump/json_writer.py +81 -7
  46. msprobe/core/data_dump/scope.py +4 -6
  47. msprobe/core/hook_manager.py +129 -70
  48. msprobe/core/monitor/csv2db.py +361 -0
  49. msprobe/core/monitor/db_utils.py +278 -0
  50. msprobe/core/monitor/utils.py +35 -1
  51. msprobe/core/service.py +31 -39
  52. msprobe/core/single_save/single_comparator.py +16 -3
  53. msprobe/docs/01.installation.md +51 -19
  54. msprobe/docs/02.config_introduction.md +16 -20
  55. msprobe/docs/03.config_examples.md +26 -0
  56. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  57. msprobe/docs/05.data_dump_PyTorch.md +6 -2
  58. msprobe/docs/06.data_dump_MindSpore.md +44 -7
  59. msprobe/docs/07.accuracy_checker_PyTorch.md +1 -1
  60. msprobe/docs/10.accuracy_compare_PyTorch.md +124 -44
  61. msprobe/docs/11.accuracy_compare_MindSpore.md +75 -7
  62. msprobe/docs/14.data_parse_PyTorch.md +1 -1
  63. msprobe/docs/19.monitor.md +94 -7
  64. msprobe/docs/21.visualization_PyTorch.md +71 -101
  65. msprobe/docs/22.visualization_MindSpore.md +69 -119
  66. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  67. msprobe/docs/25.tool_function_introduction.md +0 -1
  68. msprobe/docs/26.data_dump_PyTorch_baseline.md +7 -7
  69. msprobe/docs/28.debugger_save_instruction.md +184 -81
  70. msprobe/docs/29.data_dump_MSAdapter.md +6 -0
  71. msprobe/docs/31.config_check.md +4 -2
  72. msprobe/docs/36.calculation_result_change.md +75 -0
  73. msprobe/docs/FAQ.md +22 -1
  74. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +6 -2
  75. msprobe/docs/img/compare_result.png +0 -0
  76. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  77. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  78. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  79. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  80. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  81. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  82. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  83. msprobe/docs/visualization/mindspeed_llamafactoary_img/1.png +0 -0
  84. msprobe/docs/visualization/mindspeed_llamafactoary_img/2.png +0 -0
  85. msprobe/docs/visualization/mindspeed_llamafactoary_img/3.png +0 -0
  86. msprobe/docs/visualization/mindspeed_llamafactoary_img/4.png +0 -0
  87. msprobe/docs/visualization/mindspeed_llamafactoary_img/5.png +0 -0
  88. msprobe/docs/visualization/mindspeed_llamafactoary_img/6.png +0 -0
  89. msprobe/docs/visualization/mindspeed_llamafactoary_img/7.png +0 -0
  90. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory-qwen25vl.txt +59 -0
  91. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory1.png +0 -0
  92. msprobe/docs/visualization/mindspeed_llamafactoary_img/llamafactory2.png +0 -0
  93. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed-mm-qwen25vl.txt +80 -0
  94. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed1.png +0 -0
  95. msprobe/docs/visualization/mindspeed_llamafactoary_img/mindspeed2.png +0 -0
  96. msprobe/docs/visualization/mindspeed_llamafactory_mapping.md +330 -0
  97. msprobe/mindspore/__init__.py +1 -1
  98. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -1
  99. msprobe/mindspore/api_accuracy_checker/api_runner.py +9 -6
  100. msprobe/mindspore/api_accuracy_checker/compute_element.py +18 -12
  101. msprobe/mindspore/cell_processor.py +64 -25
  102. msprobe/mindspore/common/utils.py +51 -7
  103. msprobe/mindspore/compare/common_dir_compare.py +45 -37
  104. msprobe/mindspore/compare/ms_compare.py +10 -2
  105. msprobe/mindspore/compare/ms_graph_compare.py +47 -52
  106. msprobe/mindspore/debugger/debugger_config.py +18 -7
  107. msprobe/mindspore/debugger/precision_debugger.py +16 -12
  108. msprobe/mindspore/dump/cell_dump_process.py +130 -68
  109. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +10 -2
  110. msprobe/mindspore/dump/graph_mode_cell_dump.py +35 -9
  111. msprobe/mindspore/dump/graph_tensor_dump.py +11 -0
  112. msprobe/mindspore/dump/hook_cell/api_register.py +19 -20
  113. msprobe/mindspore/dump/hook_cell/hook_cell.py +12 -34
  114. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +142 -21
  115. msprobe/mindspore/dump/kernel_kbyk_dump.py +24 -0
  116. msprobe/mindspore/exception_dump/__init__.py +0 -0
  117. msprobe/mindspore/exception_dump/exception_dump_tool_factory.py +51 -0
  118. msprobe/mindspore/exception_dump/kernel_graph_exception_dump.py +57 -0
  119. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +5 -4
  120. msprobe/mindspore/mindspore_service.py +2 -2
  121. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +12 -7
  122. msprobe/mindspore/monitor/features.py +82 -0
  123. msprobe/mindspore/monitor/module_hook.py +168 -10
  124. msprobe/mindspore/monitor/utils.py +27 -1
  125. msprobe/mindspore/ms_config.py +12 -4
  126. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  127. msprobe/mindspore/task_handler_factory.py +3 -1
  128. msprobe/nan_analyze/graph.py +1 -1
  129. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  130. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  131. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  132. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  133. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -2
  134. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  135. msprobe/pytorch/common/utils.py +1 -21
  136. msprobe/pytorch/compare/pt_compare.py +10 -2
  137. msprobe/pytorch/{hook_module/jit_script_wrapper.py → compare/pt_diff_analyze.py} +3 -15
  138. msprobe/pytorch/compare/utils.py +2 -1
  139. msprobe/pytorch/debugger/debugger_config.py +18 -23
  140. msprobe/pytorch/dump/module_dump/hook_wrapper.py +10 -7
  141. msprobe/pytorch/dump/module_dump/module_processer.py +41 -19
  142. msprobe/pytorch/free_benchmark/main.py +7 -4
  143. msprobe/pytorch/hook_module/api_register.py +62 -24
  144. msprobe/pytorch/hook_module/hook_module.py +9 -29
  145. msprobe/pytorch/hook_module/pt_hook_manager.py +84 -15
  146. msprobe/pytorch/hook_module/script_wrapper.py +140 -0
  147. msprobe/pytorch/hook_module/support_wrap_ops.yaml +6 -0
  148. msprobe/pytorch/monitor/csv2tb.py +1 -1
  149. msprobe/pytorch/monitor/features.py +94 -0
  150. msprobe/pytorch/monitor/module_hook.py +221 -81
  151. msprobe/pytorch/monitor/module_metric.py +27 -1
  152. msprobe/pytorch/monitor/optimizer_collect.py +109 -4
  153. msprobe/pytorch/online_dispatch/dispatch.py +42 -24
  154. msprobe/pytorch/online_dispatch/dump_compare.py +1 -1
  155. msprobe/pytorch/parse_tool/lib/visualization.py +0 -1
  156. msprobe/pytorch/pt_config.py +2 -51
  157. msprobe/pytorch/pytorch_service.py +7 -14
  158. msprobe/visualization/builder/graph_builder.py +192 -63
  159. msprobe/visualization/builder/graph_merger.py +986 -0
  160. msprobe/visualization/builder/msprobe_adapter.py +17 -15
  161. msprobe/visualization/compare/graph_comparator.py +26 -16
  162. msprobe/visualization/db_utils.py +252 -0
  163. msprobe/visualization/graph/base_node.py +2 -22
  164. msprobe/visualization/graph/distributed_analyzer.py +12 -12
  165. msprobe/visualization/graph/graph.py +44 -16
  166. msprobe/visualization/graph_service.py +143 -59
  167. msprobe/visualization/utils.py +103 -4
  168. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  169. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  170. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  171. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  172. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  173. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  174. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  175. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  176. msprobe/pytorch/attl_manager.py +0 -65
  177. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/LICENSE +0 -0
  178. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/WHEEL +0 -0
  179. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/entry_points.txt +0 -0
  180. {mindstudio_probe-8.1.2.dist-info → mindstudio_probe-8.2.1.dist-info}/top_level.txt +0 -0
  181. /msprobe/{pytorch/api_accuracy_checker/tensor_transport_layer → core/compare/diff_analyze}/__init__.py +0 -0

msprobe/pytorch/monitor/optimizer_collect.py

@@ -18,6 +18,7 @@ import torch
 
 from msprobe.pytorch.common.log import logger
 from msprobe.core.monitor.utils import MVResult
+from msprobe.pytorch.monitor.module_metric import get_metrics
 from msprobe.core.common.const import MonitorConst
 
 
@@ -26,6 +27,8 @@ class OptimizerMon(object):
         self.fp16_to_fp32_param = {}
         self.torch_opt = torch_opt
         self.state = {}
+        self.origin_funcs = []
+        self.bucket_class = None
 
     def narrow_from_flatten(self, param, flatten_state):
         return flatten_state
@@ -49,11 +52,13 @@ class OptimizerMon(object):
             if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
                 continue
             grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            if grad.__class__.__name__ == 'DTensor':
+                grad = grad.to_local()
             element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
             if param.numel() != element_in_cur_partition:
                 if first_param:
                     grad = grad.flatten()[-element_in_cur_partition:]
-                else: # supposed to be the last one
+                else:  # supposed to be the last one
                     grad = grad.flatten()[:element_in_cur_partition]
             first_param = False
 
@@ -120,6 +125,59 @@ class OptimizerMon(object):
             monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
 
+    def patch_grad_sync(self, monitor):
+        def patch_sync(sync_grad_func):
+            def wrapper(bucket):
+                grad_dict = {}
+                # Megatron between core_r0.6.0 and core_r0.8.0, this bucket is Bucket.
+                # When megatron is core_r0.9.0, this bucket is _ParamAndGradBucketGroup.
+                # In megatron version core_r0.9.0, func start_grad_sync from Bucket moved to _ParamAndGradBucketGroup.
+                bucket_params_id_list = [id(params) for params in bucket.params]
+                for param, name in monitor.param2name.items():
+                    if id(param) not in bucket_params_id_list:
+                        continue
+                    grad = param.main_grad if monitor.params_have_main_grad else param.grad
+                    if grad is None:
+                        logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                        continue
+                    tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                    monitor.register_param_call_id("sync_grad_func", tag)
+                get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
+                out = sync_grad_func(bucket)
+                return out
+
+            return wrapper
+
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.origin_funcs.append(Bucket.start_grad_sync)
+            self.bucket_class = Bucket
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)
+            monitor.enable_megatron = True
+            logger.info("megatron version is >= core_r0.6.0 <= core_r0.8.0")
+        except ImportError:
+            monitor.enable_megatron = False
+
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
+            self.origin_funcs.append(_ParamAndGradBucketGroup.start_grad_sync)
+            self.bucket_class = _ParamAndGradBucketGroup
+            _ParamAndGradBucketGroup.start_grad_sync = patch_sync(_ParamAndGradBucketGroup.start_grad_sync)
+            monitor.enable_megatron = True
+            logger.info("megatron version is > core_r0.8.0 <= core_r0.9.0")
+        except ImportError:
+            monitor.enable_megatron = False | monitor.enable_megatron
+
+    def restore_grad_sync(self, monitor):
+        if not monitor.enable_megatron:
+            return
+
+        self.bucket_class.start_grad_sync = self.origin_funcs[0]
+
+
     def _get_single_state(self, torch_opt):
         state = {}
         if hasattr(torch_opt, 'param_to_cpu_states_map'):
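
The patch_grad_sync/restore_grad_sync pair added above instruments Megatron's bucketed gradient reduction by swapping a wrapper in for start_grad_sync and keeping the original callable so it can be put back later. As a minimal, hedged sketch of that wrap/observe/restore pattern (DemoBucket, patch_method and the observer callback are illustrative stand-ins, not msprobe or Megatron code):

    from typing import Callable

    class DemoBucket:
        """Stand-in for Megatron's Bucket / _ParamAndGradBucketGroup in this sketch."""
        def __init__(self, params):
            self.params = params

        def start_grad_sync(self):
            print(f"reducing {len(self.params)} grads")

    def patch_method(cls, name: str, observer: Callable) -> Callable:
        # Replace cls.<name> with a wrapper that calls observer(instance) first,
        # then the original method; return the original so it can be restored.
        original = getattr(cls, name)

        def wrapper(self, *args, **kwargs):
            observer(self)
            return original(self, *args, **kwargs)

        setattr(cls, name, wrapper)
        return original

    original = patch_method(DemoBucket, "start_grad_sync",
                            lambda bucket: print(f"pre-sync: {len(bucket.params)} params"))
    DemoBucket([1, 2, 3]).start_grad_sync()   # observer fires, then the original reduction runs
    DemoBucket.start_grad_sync = original     # undo the patch, as restore_grad_sync does

In the actual hunk the observer body is the grad_dict/get_metrics collection, and restore_grad_sync reinstalls origin_funcs[0] on bucket_class.
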
@@ -131,7 +189,7 @@ class OptimizerMon(object):
         self.state.update(state)
 
 
-class MixPrecisionOptimizerMon(OptimizerMon):
+class MegatronMixPrecisionOptimizerMon(OptimizerMon):
     """
     Mixed-precision optimizer monitoring class: monitors and manages the optimizer during mixed-precision training.
     Mixed-precision training speeds up training and reduces memory use by appropriately lowering the precision of some computations.
@@ -161,7 +219,7 @@ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
             super().map_fp16_to_fp32_param(opt)
 
 
-class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+class MegatronChainedMixPrecisionOptimizerMon(MegatronMixPrecisionOptimizerMon):
     def map_fp16_to_fp32_param(self, torch_opt):
         for opt in torch_opt.chained_optimizers:
             super().map_fp16_to_fp32_param(opt)
@@ -248,6 +306,12 @@ class DeepSpeedZeroOptimizerMon(OptimizerMon):
                 grad_dict[tag] = grad
 
         return grad_dict
+
+    def patch_grad_sync(self, monitor):
+        pass
+
+    def restore_grad_sync(self, monitor):
+        pass
 
 
 class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
@@ -291,6 +355,47 @@ class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
                     break
 
 
+    def patch_grad_sync(self, monitor):
+        def patch_sync(reduce_func):
+            def wrapper(zero_optimizer, *args, **kwargs):
+                grad_dict = {}
+                for i, param, _ in zero_optimizer.params_in_ipg_bucket:
+                    if isinstance(param, int):  # for ds >= 0.17.0
+                        param = zero_optimizer.bit16_groups[i][param]
+                    name = monitor.param2name[param]
+                    tag = monitor.name2tag.get(name, {}).get(MonitorConst.PRE_GRAD)
+                    grad_dict[tag] = zero_optimizer.get_gradient_for_reduction(param)
+                    monitor.register_param_call_id("sync_grad_func", tag)
+                get_metrics(monitor.ops, grad_dict, monitor.eps, monitor.grad_context.pre)
+                out = reduce_func(zero_optimizer, *args, **kwargs)
+                return out
+
+            return wrapper
+        try:
+            from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+            self.origin_funcs = [
+                DeepSpeedZeroOptimizer.average_tensor,
+                DeepSpeedZeroOptimizer.buffered_reduce_fallback
+            ]
+            DeepSpeedZeroOptimizer.average_tensor = patch_sync(DeepSpeedZeroOptimizer.average_tensor)
+            DeepSpeedZeroOptimizer.buffered_reduce_fallback = \
+                patch_sync(DeepSpeedZeroOptimizer.buffered_reduce_fallback)
+            monitor.enable_deepspeed = True
+            logger.info('deepspeed enabled')
+        except Exception as e:
+            monitor.enable_deepspeed = False | monitor.enable_deepspeed
+            logger.warning('Seems using deepspeed zero 1 or 2. But patch average tensor failed')
+
+    def restore_grad_sync(self, monitor):
+        if not monitor.enable_deepspeed:
+            return
+
+        from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
+        DeepSpeedZeroOptimizer.average_tensor = self.origin_funcs[0]
+        DeepSpeedZeroOptimizer.buffered_reduce_fallback = self.origin_funcs[1]
+
+
+
 class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
     def __init__(self, torch_opt):
         super().__init__(torch_opt)
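
Note that both integrations are applied defensively: the Megatron patch is attempted for two symbol names and the DeepSpeed patch sits in a broad except, with monitor.enable_megatron / monitor.enable_deepspeed recording whether anything was actually patched so restore_grad_sync can no-op otherwise. A generic sketch of that optional-dependency probing (detect_backend and the flag name here are illustrative, not msprobe's API):

    import importlib

    def detect_backend(module_name: str, symbol: str) -> bool:
        # True only if the optional dependency is importable and exposes the
        # attribute we intend to patch; otherwise the feature stays disabled
        # instead of failing at import time.
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            return False
        return hasattr(module, symbol)

    buffer_mod = "megatron.core.distributed.param_and_grad_buffer"
    enable_megatron = (detect_backend(buffer_mod, "Bucket")
                       or detect_backend(buffer_mod, "_ParamAndGradBucketGroup"))
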
@@ -314,7 +419,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
 class OptimizerMonFactory:
     _optimizer_mon_map = {
         "FP32Optimizer": OptimizerMon,
-        "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "Float16OptimizerWithFloat16Params": MegatronMixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
         "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
msprobe/pytorch/online_dispatch/dispatch.py

@@ -17,7 +17,7 @@ import json
 import os
 import time
 import multiprocessing
-from multiprocessing import Pool
+from multiprocessing import Pool, Lock
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -39,6 +39,7 @@ from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, ge
 from msprobe.pytorch.online_dispatch.compare import Comparator
 from msprobe.core.common.utils import check_str_param, safe_get_value
 
+child_global_lock = None
 current_time = time.strftime("%Y%m%d%H%M%S")
 RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
 DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
@@ -86,14 +87,14 @@ class PtdbgDispatch(TorchDispatchMode):
         yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml")
         self.get_ops(yaml_path)
 
-        self.lock = None
+        self.lock = Lock() if process_num > 0 else None
         max_process_num = max(int((multiprocessing.cpu_count() + 1) // Const.CPU_QUARTER), 1)
         if process_num > max_process_num:
             logger.error(f"process_num should be less than or equal to {max_process_num}, but got {process_num}!")
             raise DispatchException(f'process_num should be less than or equal to {max_process_num}, '
                                     f'but got {process_num}!')
         if process_num > 0:
-            self.pool = Pool(process_num)
+            self.pool = Pool(process_num, initializer=self._init_child_process, initargs=(self.lock,))
         if debug:
             logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} '
                         f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], '
@@ -114,18 +115,17 @@ class PtdbgDispatch(TorchDispatchMode):
             logger.error("Please check train log, An exception may have occurred!")
             return
         check_file_or_directory_path(summary_path, False)
-        fp_handle = FileOpen(summary_path, "r")
-        while True:
-            json_line_data = fp_handle.readline()
-            if json_line_data == '\n':
-                continue
-            if len(json_line_data) == 0:
-                break
-            msg = json.loads(json_line_data)
-            if len(msg) < 2:
-                raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
-            self.all_summary[msg[0]] = msg[1]
-        fp_handle.close()
+        with FileOpen(summary_path, "r") as fp_handle:
+            while True:
+                json_line_data = fp_handle.readline()
+                if json_line_data == '\n':
+                    continue
+                if len(json_line_data) == 0:
+                    break
+                msg = json.loads(json_line_data)
+                if len(msg) < 2:
+                    raise ValueError("JSON data does not contain enough elements. Expected at least 2 elements.")
+                self.all_summary[msg[0]] = msg[1]
 
         if self.debug_flag:
             input_num = 0
@@ -163,11 +163,16 @@ class PtdbgDispatch(TorchDispatchMode):
 
         call_stack = get_callstack()
         self.call_stack_list.append(call_stack)
-        self.api_index += 1
-        if aten_api not in self.single_api_index_dict:
-            self.single_api_index_dict[aten_api] = 1
-        else:
-            self.single_api_index_dict[aten_api] += 1
+
+        self.lock.acquire() if self.process_num > 0 else None
+        try:
+            self.api_index += 1
+            if aten_api not in self.single_api_index_dict:
+                self.single_api_index_dict[aten_api] = 1
+            else:
+                self.single_api_index_dict[aten_api] += 1
+        finally:
+            self.lock.release() if self.process_num > 0 else None
 
         run_param = self.get_run_param(aten_api, func.__name__, aten_api_overload_name)
 
@@ -180,7 +185,7 @@ class PtdbgDispatch(TorchDispatchMode):
         cpu_kwargs = []
         data_to_cpu(args, 0, cpu_args)
         data_to_cpu(kwargs, 0, cpu_kwargs)
-
+
         cpu_args = safe_get_value(cpu_args, 0, "cpu_args")
         cpu_kwargs = safe_get_value(cpu_kwargs, 0, "cpu_kwargs")
 
@@ -194,7 +199,12 @@ class PtdbgDispatch(TorchDispatchMode):
         try:
             cpu_out = func(*cpu_args, **cpu_kwargs)
         except RuntimeError as e:
-            self.api_index -= 1
+            self.lock.acquire() if self.process_num > 0 else None
+            try:
+                self.api_index -= 1
+                self.single_api_index_dict[aten_api] -= 1
+            finally:
+                self.lock.release() if self.process_num > 0 else None
             logger.warning(f"RuntimeError: {e}")
             logger.warning(f"This aten_api {aten_api} does not support running on cpu, so skip it.")
             return npu_out
@@ -215,7 +225,7 @@ class PtdbgDispatch(TorchDispatchMode):
         run_param.process_flag = True
         if self.check_fun(func, run_param):
             data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out,
-                                         self.lock)
+                                         child_global_lock)
             self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info),
                                   error_callback=error_call)
         else:
@@ -233,12 +243,20 @@ class PtdbgDispatch(TorchDispatchMode):
                 return True
         return False
 
+    @staticmethod
+    def _init_child_process(lock):
+        global child_global_lock
+        child_global_lock = lock
+
     def get_dir_name(self, tag):
         # guarantee file uniqueness
         time.sleep(1)
-        time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))
+        # Time format: year-month-day-hour-minute-second-millisecond (to a thousandth of a second)
+        time_now = time.strftime("%Y%m%d%H%M%S%f", time.localtime(time.time()))[:-3]  # keep the first 3 millisecond digits
+
         if tag is None or not isinstance(tag, str):
             logger.warning('There is not tag or the type of tag is not string.')
+            # Directory name format: msprobe_rank{device id}_{millisecond timestamp}
             dir_name = f'msprobe_rank{self.device_id}_{time_now}'
         else:
             dir_name = f'msprobe_{tag}_rank{self.device_id}_{time_now}'
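
The Pool change above is the standard way to share a multiprocessing.Lock with worker processes: a Lock cannot be shipped inside apply_async arguments (it must be inherited when the worker starts), so it is handed over once through initializer/initargs and stored in a module-level global, which is what child_global_lock is for. A self-contained sketch of the same pattern, independent of msprobe (the file name and helper names are made up):

    from multiprocessing import Pool, Lock

    _worker_lock = None  # set once per worker by the initializer below


    def _init_worker(lock):
        # Runs in every worker at startup; stashes the inherited lock as a global.
        global _worker_lock
        _worker_lock = lock


    def append_line(path, text):
        # Serialize writes from all workers through the shared lock.
        with _worker_lock:
            with open(path, "a") as f:
                f.write(text + "\n")


    if __name__ == "__main__":
        lock = Lock()
        with Pool(2, initializer=_init_worker, initargs=(lock,)) as pool:
            for i in range(4):
                pool.apply_async(append_line, ("out.txt", f"line {i}"))
            pool.close()
            pool.join()
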
msprobe/pytorch/online_dispatch/dump_compare.py

@@ -21,7 +21,7 @@ from datetime import datetime, timezone
 import torch
 from msprobe.core.common.const import Const
 from msprobe.core.common.decorator import recursion_depth_decorator
-from msprobe.core.common.file_utils import FileOpen, save_npy, save_json, check_link, remove_path
+from msprobe.core.common.file_utils import FileOpen, save_npy, save_json, remove_path, check_link
 from msprobe.pytorch.common.log import logger
 
 

msprobe/pytorch/parse_tool/lib/visualization.py

@@ -83,4 +83,3 @@ class Visualization:
                 self.util.log.info("\nStatistic Info:")
                 title_printed = True
             self.util.log.info(summery_info)
-        pkl_handle.close()
msprobe/pytorch/pt_config.py

@@ -35,48 +35,15 @@ from msprobe.pytorch.hook_module.utils import get_ops
 class TensorConfig(BaseConfig):
     def __init__(self, json_config):
         super().__init__(json_config)
-        self.online_run_ut = json_config.get("online_run_ut", False)
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.tls_path = json_config.get("tls_path", "./")
-        self.online_run_ut_recompute = json_config.get("online_run_ut_recompute", False)
         self.check_config()
         self._check_summary_mode()
         self._check_file_format()
-        if self.online_run_ut:
-            self._check_online_run_ut()
+
 
     def _check_file_format(self):
         if self.file_format is not None and self.file_format not in ["npy", "bin"]:
             raise Exception("file_format is invalid")
 
-    def _check_online_run_ut(self):
-        if not isinstance(self.online_run_ut, bool):
-            raise Exception(f"online_run_ut: {self.online_run_ut} is invalid.")
-
-        if not isinstance(self.online_run_ut_recompute, bool):
-            raise Exception(f"online_run_ut_recompute: {self.online_run_ut_recompute} is invalid.")
-
-        if self.nfs_path:
-            check_file_or_directory_path(self.nfs_path, isdir=True)
-            return
-
-        if self.tls_path:
-            check_file_or_directory_path(self.tls_path, isdir=True)
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.key"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "client.crt"))
-            check_file_or_directory_path(os.path.join(self.tls_path, "ca.crt"))
-            crl_path = os.path.join(self.tls_path, "crl.pem")
-            if os.path.exists(crl_path):
-                check_file_or_directory_path(crl_path)
-
-        if not isinstance(self.host, str) or not re.match(Const.ipv4_pattern, self.host):
-            raise Exception(f"host: {self.host} is invalid.")
-
-        if not isinstance(self.port, int) or not (0 < self.port <= 65535):
-            raise Exception(f"port: {self.port} is invalid, port range 0-65535.")
-
 
 class StatisticsConfig(BaseConfig):
     def __init__(self, json_config):
@@ -251,12 +218,7 @@ class RunUTConfig(BaseConfig):
         self.white_list = json_config.get("white_list", Const.DEFAULT_LIST)
         self.black_list = json_config.get("black_list", Const.DEFAULT_LIST)
         self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH)
-        self.is_online = json_config.get("is_online", False)
-        self.nfs_path = json_config.get("nfs_path", "")
-        self.host = json_config.get("host", "")
-        self.port = json_config.get("port", -1)
-        self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST)
-        self.tls_path = json_config.get("tls_path", "./")
+
         self.check_run_ut_config()
 
     @classmethod
@@ -274,22 +236,11 @@ class RunUTConfig(BaseConfig):
         if not os.path.exists(error_data_path):
             raise Exception("error_data_path: %s does not exist" % error_data_path)
 
-    @classmethod
-    def check_nfs_path_config(cls, nfs_path):
-        if nfs_path:
-            FileChecker(nfs_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
-
-    @classmethod
-    def check_tls_path_config(cls, tls_path):
-        if tls_path:
-            FileChecker(tls_path, FileCheckConst.DIR, FileCheckConst.READ_ABLE).common_check()
 
 
     def check_run_ut_config(self):
         RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list)
         RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list)
         RunUTConfig.check_error_data_path_config(self.error_data_path)
-        RunUTConfig.check_nfs_path_config(self.nfs_path)
-        RunUTConfig.check_tls_path_config(self.tls_path)
 
 class GradToolConfig(BaseConfig):
msprobe/pytorch/pytorch_service.py

@@ -15,18 +15,14 @@
 
 from msprobe.core.common.utils import Const
 from msprobe.core.service import BaseService
-from msprobe.pytorch.attl_manager import ATTLManager
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import get_rank_if_initialized, torch_version_above_or_equal_2
+from msprobe.pytorch.common.utils import get_rank_if_initialized
 from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser
-from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate
+from msprobe.pytorch.hook_module.api_register import get_api_register, ApiTemplate, redirect_wait
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
-from msprobe.pytorch.hook_module.jit_script_wrapper import wrap_jit_script_func
 from msprobe.pytorch.hook_module.pt_hook_manager import PytorchHookManager
 from msprobe.pytorch.hook_module.register_optimizer_hook import register_optimizer_hook
-
-if torch_version_above_or_equal_2:
-    from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.dump_dispatch import run_ut_dispatch
+from msprobe.pytorch.hook_module.script_wrapper import wrap_script_func, preprocess_func
 
 
 class PytorchService(BaseService):
@@ -45,27 +41,24 @@ class PytorchService(BaseService):
         self.logger = logger
         self.api_register = get_api_register()
         self.module_processor = ModuleProcesser(self.data_collector.scope)
-        self.attl_manager = ATTLManager(self.config)
-        self.hook_manager = PytorchHookManager(self.data_collector, self.config, self.attl_manager)
+        self.hook_manager = PytorchHookManager(self.data_collector, self.config)
         self.api_template = ApiTemplate
 
     def _register_hook(self):
-        self.attl_manager.attl_init()
         if self._is_mix_level:
             register_optimizer_hook(self.data_collector)
 
     def _register_api_hook(self):
+        preprocess_func()
         super()._register_api_hook()
-        wrap_jit_script_func()
+        wrap_script_func()
+        redirect_wait()
 
     def _register_module_hook(self):
         ModuleProcesser.enable_module_dump = True
        self.module_processor.register_module_hook(self.model, self.build_hook)
         self.logger.info(f"The module {self.config.task} hook function is successfully mounted to the model.")
 
-    def _run_ut_dispatch(self, status):
-        if torch_version_above_or_equal_2:
-            run_ut_dispatch(self.attl_manager.attl, status, self.config.online_run_ut_recompute)
 
     def _reset_status(self):
         super()._reset_status()