mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213)
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -144,6 +144,20 @@ class IdentMetric(Metric):
         return tensor


+@register_config_metric("shape")
+class ShapeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.shape
+
+
+@register_config_metric("dtype")
+class DtypeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.dtype
+
+
 def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
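
The additions above follow the registration pattern already used in this module: each Metric subclass is registered under a config key via the register_config_metric decorator, and get_metrics later resolves those keys for every tag in tag2tensor. A minimal sketch of a user-defined metric in the same style; the NumelMetric name is hypothetical, and the import path is a guess based on the changed-files list (this hunk appears to come from msprobe/pytorch/monitor/module_metric.py, +14 -0):

# Hypothetical example, not part of the package.
# Assumed import path; adjust if the module layout differs.
from msprobe.pytorch.monitor.module_metric import Metric, register_config_metric

@register_config_metric("numel")
class NumelMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        # eps is unused here; it is kept only to match the common Metric signature
        return tensor.numel()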
@@ -12,129 +12,123 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from collections import defaultdict
+from abc import abstractmethod

 import torch
-import torch.distributed as dist

 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
+from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.core.common.const import MonitorConst


 class OptimizerMon(object):
-    def __init__(self) -> None:
+    def __init__(self, torch_opt) -> None:
         self.fp16_to_fp32_param = {}
-        self.is_stage3 = False
+        self.torch_opt = torch_opt
+        self.state = {}

-    def fetch_mv(self, monitor, torch_opt, params2name):
-        pass
+    def narrow_from_flatten(self, param, flatten_state):
+        return flatten_state
+
+    def get_state(self, torch_opt):
+        if hasattr(torch_opt, 'chained_optimizers'):
+            for opt in torch_opt.chained_optimizers:
+                self._get_single_state(opt)
+        else:
+            self._get_single_state(torch_opt)

-    def _fetch_mv_in_adam(self, monitor, torch_opt, params2name):
-        exp_avg_dict = defaultdict(float)
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
+    def fetch_grad(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+
+        grad_dict = {}
+        first_param = True
         for param, name in params2name.items():
-            if param in self.fp16_to_fp32_param:
-                param = self.fp16_to_fp32_param[param]
-
-            if param in torch_opt.state:
-                state_param = torch_opt.state.get(param, None)
-                exp_avg = state_param.get("exp_avg", None)
-                exp_avg_sq = state_param.get("exp_avg_sq", None)
-                if exp_avg is None or exp_avg_sq is None:
-                    logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                    continue
+            if monitor.duplicate_param.get(name, False):
+                continue
+            if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                continue
+            grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+            if param.numel() != element_in_cur_partition:
+                if first_param:
+                    grad = grad.flatten()[-element_in_cur_partition:]
+                else:  # supposed to be the last one
+                    grad = grad.flatten()[:element_in_cur_partition]
+            first_param = False
+
+            if grad is None:
+                if not monitor.fsdp_wrapped_module:
+                    logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        return grad_dict
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        pass
+
+    def fetch_mv(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+        if not self.state:
+            self.get_state(self.torch_opt)
+
+        exp_avg_dict = {}
+        exp_avg_sq_dict = {}
+        update_dict = {}
+        ratio_dict = {}
+
+        if not self.state:
+            logger.warning('optimizer state can not accessed')
+            return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+        for lp_param, name in params2name.items():
+            if lp_param in self.fp16_to_fp32_param:
+                hp_param = self.fp16_to_fp32_param[lp_param]
+            else:
+                hp_param = lp_param
+
+            if hp_param in self.state:
+                state_param = self.state.get(hp_param, {})
+                exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
                 if monitor.mv_distribution:
                     exp_avg_dict[name] = exp_avg
                     exp_avg_sq_dict[name] = exp_avg_sq
                 if monitor.mg_direction:
                     exp_avg_dict[name] = exp_avg
                 if monitor.ur_distribution:
-                    if len(torch_opt.param_groups) > 1:
-                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                    if len(self.torch_opt.param_groups) > 1:
+                        logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.")
                     if 'step' in state_param:
                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                    elif 'step' in torch_opt.param_groups[0]:
-                        step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                    elif 'step' in self.torch_opt.param_groups[0]:
+                        step = self.torch_opt.param_groups[0]['step']  # AdamW from mindspeed
                     else:
                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
                         continue
-                    exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                    exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                    if exp_avg is None or exp_avg_sq is None:
+                        logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
+                        continue
+                    exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
                     ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
-    def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat):
-        exp_avg_dict = defaultdict(float)
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
-        param2name = defaultdict()
-        fp32_partitioned_groups_flat_grad = defaultdict()
-        partition_id = dist.get_rank()
-
-        def get_flatten_grad(self, optimizer, group_idx):
-            if fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
-                        optimizer.averaged_gradients[group_idx],
-                        int(optimizer.partition_size[group_idx])
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                else:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
-                        optimizer.averaged_gradients[group_idx]
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                return fp32_partitioned_groups_flat_grad
-            else:
-                return fp32_partitioned_groups_flat[group_idx].grad
-
-        for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
-
-        for name in params2name.values():
-            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
-            if group_with_rank != partition_id and isinstance(group_with_rank, int):
-                continue
-            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
-            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
-            param2name[fp32_param] = name
-            if not torch_opt.state:
-                continue
-            state_param = list(torch_opt.state.values())[group_idx]
-            exp_avg = state_param.get("exp_avg", None)
-            exp_avg_sq = state_param.get("exp_avg_sq", None)
-            if exp_avg is None or exp_avg_sq is None:
-                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                continue
-            exp_avg = exp_avg[start_idx: end_idx]
-            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
-            if monitor.mv_distribution:
-                exp_avg_dict[name] = exp_avg
-                exp_avg_sq_dict[name] = exp_avg_sq
-            if monitor.mg_direction:
-                exp_avg_dict[name] = exp_avg
-            if monitor.ur_distribution:
-                if 'step' in state_param:
-                    step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                elif 'step' in torch_opt.param_groups[group_idx]:
-                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
-                else:
-                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
-                    continue
-                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        del fp32_partitioned_groups_flat_grad
-        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
-                            grad=param2name)
+
+    def _get_single_state(self, torch_opt):
+        state = {}
+        if hasattr(torch_opt, 'param_to_cpu_states_map'):
+            state = torch_opt.param_to_cpu_states_map
+        elif hasattr(torch_opt, 'state'):
+            state = torch_opt.state
+        elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'):
+            state = torch_opt.optimizer.state
+        self.state.update(state)


 class MixPrecisionOptimizerMon(OptimizerMon):
@@ -142,21 +136,14 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     Mixed-precision optimizer monitoring class: monitors and manages the optimizer during mixed-precision training.
     Mixed-precision training speeds up training and reduces memory consumption by appropriately lowering the precision of certain computations.
     """
-
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param

-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-

 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         if not (hasattr(torch_opt, "model_float16_groups") and
                 hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
@@ -167,192 +154,176 @@ class MegatronDistributedOptimizerMon(OptimizerMon):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param

-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
-
-class MegatronFP32OptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-

 class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            for opt in torch_opt.chained_optimizers:
-                self.map_fp16_tp_fp32_param(opt)
-
-        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)


 class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            for opt in torch_opt.chained_optimizers:
-                self.map_fp16_tp_fp32_param(opt)
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)

-        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)

-
-class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
-    def get_group_index(self, torch_opt):
-        bit16_groups = torch_opt.bf16_groups
-        param2group = defaultdict()
-        for group_idx, bit16_group in enumerate(bit16_groups):
+class DeepSpeedZeroOptimizerMon(OptimizerMon):
+    """
+    Base monitor class for DeepSpeed ZeRO optimizer.
+    ZeRO stage 0 no partition
+    ZeRO stage 1 partitions optimizer states across data parallel processes.
+    ZeRO stage 2 additionally partitions gradients.
+    ZeRO stage 3 additionally partitions parameters.
+
+    This class provides monitoring capabilities for ZeRO optimizers by:
+    - Handling gradient collection for different ZeRO stages
+    - Managing optimizer state access for monitoring
+    """
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = ''
+        self.bit16_groups = []
+        self.fp32_flat_groups = []
+        self.param2group = ()
+        self.param2index = []
+        self.group_offset = {}
+
+    @abstractmethod
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        raise NotImplementedError
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address is None
+
+    def get_position(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address.start, hp_address.numel
+
+    def get_group_index(self):
+        param2group = {}
+        for group_idx, bit16_group in enumerate(self.bit16_groups):
             for param in bit16_group:
                 param2group[param] = group_idx
         return param2group
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        param2group = self.get_group_index(torch_opt)
-        exp_avg_dict = defaultdict(float)
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
-
-        param_slice_mappings = torch_opt.state_dict()['param_slice_mappings']
-        for param, name in params2name.items():
-            group_idx = param2group[param]
-            state = torch_opt.optimizer.state[torch_opt.fp32_groups_flat_partition[group_idx]]
-            if state.get('exp_avg', None) is None:
-                logger.warning(f"optimizer state is None. Something is wrong if this is not the first step")
-                break
-            param_slice_mapping = param_slice_mappings[group_idx]
-            hp_address = param_slice_mapping.get(torch_opt.param_names[param])
-            if hp_address is None:
+
+    def get_param_index(self, lp_param, group_idx):
+        if not self.param2index:
+            for group in self.bit16_groups:
+                param2index = {}
+                for index, param in enumerate(group):
+                    param2index[param] = index
+                self.param2index.append(param2index)
+
+        return self.param2index[group_idx][lp_param]
+
+    def narrow_from_flatten(self, param, flatten_state):
+        if flatten_state is None:
+            return flatten_state
+        group_idx = self.param2group[param]
+        if self.param_not_in_partition(param, group_idx):
+            return None
+        start, numel = self.get_position(param, group_idx)
+        return flatten_state.narrow(0, start, numel)
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for group_idx, group in enumerate(self.bit16_groups):
+            for param in group:
+                self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+    def fetch_grad(self, monitor, params2name):
+        grad_dict = {}
+        for lp_param, name in params2name.items():
+            group_idx = self.param2group[lp_param]
+            param_id = self.get_param_index(lp_param, group_idx)
+            if self.param_not_in_partition(lp_param, group_idx):
                 continue
-            start = hp_address.start
-            numel = hp_address.numel
-
-            if monitor.mv_distribution:
-                exp_avg_dict[name] = state['exp_avg'].narrow(0, start, numel)
-                exp_avg_sq_dict[name] = state['exp_avg_sq'].narrow(0, start, numel)
-            if monitor.mg_direction:
-                exp_avg_dict[name] = state['exp'].narrow(0, start, numel)
-            if monitor.ur_distribution:
-                if len(torch_opt.param_groups) > 1:
-                    logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
-                if 'step' in state:
-                    step = state['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                elif 'step' in torch_opt.param_groups[0]:
-                    step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+            if self.stage == '1or2':
+                param_id = param_id - self.group_offset[group_idx] - 1
+            grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        return grad_dict
+
+
+class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '0'
+        self.bit16_groups = torch_opt.bf16_groups
+        self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition
+        self.param2group = self.get_group_index()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '1or2'
+        self.bit16_groups = torch_opt.bit16_groups
+        self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups
+        self.param2group = self.get_group_index()
+        self.group_offset = {}
+        self.get_group_offset()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        if getattr(self.torch_opt, "cpu_offload", False):
+            grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad
+            start, numel = self.get_position(lp_param, group_idx)
+            grad = grads.narrow(0, start, numel)
+        else:
+            grad = self.torch_opt.averaged_gradients[group_idx][param_id]
+        return grad
+
+    def get_group_offset(self):
+        for group_idx, group in enumerate(self.bit16_groups):
+            self.group_offset[group_idx] = -1
+            for lp_param in group:
+                if self.param_not_in_partition(lp_param, group_idx):
+                    self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
                 else:
-                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
-                    continue
-                exp_avg = state['exp_avg'].narrow(0, start, numel)
-                exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
-                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
+                    break

-class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-    def get_param_index(self, params2name, name2index, torch_opt):
-        fp16_groups = torch_opt.fp16_partitioned_groups
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, fp16_group in enumerate(fp16_groups):
-            for param in fp16_group:
-                param_length = len(param.flatten())
-                index_length[idx] = (index, index + param_length, group_idx)
-                index += param_length
-                idx += 1
-        for _, name in params2name.items():
-            idx = name2index[name]
-            start_idx, end_idx, group_idx = index_length[idx]
-            name2indices[name] = (start_idx, end_idx, group_idx, None)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        self.is_stage3 = True
-        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-    @staticmethod
-    def get_group_index(fp32_length, world_size, index):
-        for i in range(len(fp32_length) - 1):
-            if fp32_length[i] <= index < fp32_length[i + 1]:
-                interval_start = fp32_length[i]
-                interval_length = fp32_length[i + 1] - fp32_length[i]
-                sub_interval_length = interval_length // world_size
-                sub_index = (index - interval_start) // sub_interval_length
-                sub_interval_start = interval_start + sub_index * sub_interval_length
-                return sub_interval_start, min(sub_index, world_size - 1)
-        return fp32_length[-1], 0
-
-    def get_param_index(self, params2name, name2index, torch_opt):
-        padding = torch_opt.groups_padding
-        world_size = dist.get_world_size()
-        fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
-            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
-
-        bf16_groups = []
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
-            bf16_groups.extend(bf16_group)
-            for param in bf16_group:
-                param_length = len(param.flatten())
-                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
-                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
-                index += param_length
-                idx += 1
-        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
-        for _, name in params2name.items():
-            name_index = name2index[name]
-            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
-            need_padding = True if group_with_rank == world_size - 1 else False
-            new_start_idx = start_idx - group_index
-            new_end_idx = end_idx - group_index
-            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
-                    group_length - 1) == 0:
-                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
-            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DummyOptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '3'
+        self.bit16_groups = torch_opt.fp16_groups
+        self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
+        self.param2group = self.get_group_index()
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        """Each param is partitioned across all zero ranks"""
+        return False
+
+    def get_position(self, lp_param, group_idx):
+        param_id = self.torch_opt.get_param_id(lp_param)
+        return self.torch_opt.grad_position[param_id][1:]
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.averaged_gradients[group_idx][param_id]


 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "FP32Optimizer": MegatronFP32OptimizerMon,
+        "FP32Optimizer": OptimizerMon,
         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
        "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
        "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
        "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
        "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "Adam": DummyOptimizerMon
+        "Adam": OptimizerMon
     }

     @staticmethod
@@ -361,6 +332,7 @@ class OptimizerMonFactory:
         optimizer_class = optimizer.__class__.__name__
         if optimizer_class == "ChainedOptimizer":
             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+        logger.info(f'The optimizer type is {optimizer_class}')

-        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
-        return optimizer_mon_class(), optimizer_class
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+        return optimizer_mon_class(optimizer)
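
Taken together, these hunks change the monitor contract in optimizer_collect.py: each OptimizerMon subclass is now constructed with the optimizer it wraps, and callers invoke fetch_mv(monitor, params2name) and fetch_grad(monitor, params2name) instead of passing torch_opt on every call. A minimal caller-side sketch based only on the constructors and method signatures visible in this diff; monitor and params2name stand for the monitor object and the {param: name} mapping that the module hooks normally supply, and are placeholders here rather than the package's actual call site:

# Sketch only: pick the monitor class the way _optimizer_mon_map does,
# then pull optimizer state and gradients for the monitored parameters.
opt_mon = DeepSpeedZeroOptimizerStage1or2Mon(optimizer)  # e.g. for a "DeepSpeedZeroOptimizer" instance
mv_result = opt_mon.fetch_mv(monitor, params2name)       # MVResult(exp_avg, exp_avg_sq, update, ratio)
grad_dict = opt_mon.fetch_grad(monitor, params2name)     # {tag: grad}, tags taken from MonitorConst.POST_GRAD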