mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (261)
  1. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
  2. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
  3. msprobe/README.md +57 -21
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +224 -82
  6. msprobe/core/common/decorator.py +50 -0
  7. msprobe/core/common/exceptions.py +5 -3
  8. msprobe/core/common/file_utils.py +274 -40
  9. msprobe/core/common/framework_adapter.py +169 -0
  10. msprobe/core/common/global_lock.py +86 -0
  11. msprobe/core/common/runtime.py +25 -0
  12. msprobe/core/common/utils.py +148 -72
  13. msprobe/core/common_config.py +7 -0
  14. msprobe/core/compare/acc_compare.py +640 -462
  15. msprobe/core/compare/check.py +36 -107
  16. msprobe/core/compare/compare_cli.py +4 -0
  17. msprobe/core/compare/config.py +72 -0
  18. msprobe/core/compare/highlight.py +217 -215
  19. msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
  20. msprobe/core/compare/merge_result/merge_result.py +12 -6
  21. msprobe/core/compare/multiprocessing_compute.py +227 -107
  22. msprobe/core/compare/npy_compare.py +32 -16
  23. msprobe/core/compare/utils.py +218 -244
  24. msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
  25. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  26. msprobe/core/config_check/checkers/base_checker.py +60 -0
  27. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  28. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  29. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  30. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  31. msprobe/core/config_check/checkers/random_checker.py +367 -0
  32. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  33. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  34. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  35. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  36. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  37. msprobe/core/config_check/config_check_cli.py +51 -0
  38. msprobe/core/config_check/config_checker.py +100 -0
  39. msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
  40. msprobe/core/config_check/resource/env.yaml +57 -0
  41. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  42. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  43. msprobe/core/config_check/utils/utils.py +107 -0
  44. msprobe/core/data_dump/api_registry.py +239 -0
  45. msprobe/core/data_dump/data_collector.py +36 -9
  46. msprobe/core/data_dump/data_processor/base.py +74 -53
  47. msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
  48. msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
  49. msprobe/core/data_dump/json_writer.py +146 -57
  50. msprobe/core/debugger/precision_debugger.py +143 -0
  51. msprobe/core/grad_probe/constant.py +2 -1
  52. msprobe/core/grad_probe/grad_compare.py +2 -2
  53. msprobe/core/grad_probe/utils.py +1 -1
  54. msprobe/core/hook_manager.py +242 -0
  55. msprobe/core/monitor/anomaly_processor.py +384 -0
  56. msprobe/core/overflow_check/abnormal_scene.py +2 -0
  57. msprobe/core/service.py +356 -0
  58. msprobe/core/single_save/__init__.py +0 -0
  59. msprobe/core/single_save/single_comparator.py +243 -0
  60. msprobe/core/single_save/single_saver.py +157 -0
  61. msprobe/docs/01.installation.md +6 -5
  62. msprobe/docs/02.config_introduction.md +89 -30
  63. msprobe/docs/03.config_examples.md +1 -0
  64. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  65. msprobe/docs/05.data_dump_PyTorch.md +184 -50
  66. msprobe/docs/06.data_dump_MindSpore.md +193 -28
  67. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
  68. msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
  69. msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
  70. msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
  71. msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
  72. msprobe/docs/12.overflow_check_PyTorch.md +5 -3
  73. msprobe/docs/13.overflow_check_MindSpore.md +6 -4
  74. msprobe/docs/14.data_parse_PyTorch.md +4 -10
  75. msprobe/docs/17.grad_probe.md +2 -1
  76. msprobe/docs/18.online_dispatch.md +3 -3
  77. msprobe/docs/19.monitor.md +211 -103
  78. msprobe/docs/21.visualization_PyTorch.md +100 -28
  79. msprobe/docs/22.visualization_MindSpore.md +103 -31
  80. msprobe/docs/23.generate_operator_PyTorch.md +9 -9
  81. msprobe/docs/25.tool_function_introduction.md +23 -22
  82. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  83. msprobe/docs/27.dump_json_instruction.md +278 -8
  84. msprobe/docs/28.debugger_save_instruction.md +111 -20
  85. msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
  86. msprobe/docs/29.data_dump_MSAdapter.md +229 -0
  87. msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
  88. msprobe/docs/31.config_check.md +95 -0
  89. msprobe/docs/32.ckpt_compare.md +69 -0
  90. msprobe/docs/33.generate_operator_MindSpore.md +190 -0
  91. msprobe/docs/34.RL_collect.md +92 -0
  92. msprobe/docs/35.nan_analyze.md +72 -0
  93. msprobe/docs/FAQ.md +3 -11
  94. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  95. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  96. msprobe/docs/img/compare_result.png +0 -0
  97. msprobe/docs/img/merge_result.png +0 -0
  98. msprobe/docs/img/save_compare_result_sample.png +0 -0
  99. msprobe/docs/img/visualization/proxy.png +0 -0
  100. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  101. msprobe/docs/img/visualization/vis_match_info.png +0 -0
  102. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  103. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  104. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  105. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  106. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  107. msprobe/mindspore/__init__.py +3 -3
  108. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
  109. msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
  110. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
  111. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
  112. msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
  113. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  114. msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
  115. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  116. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
  117. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  118. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
  119. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  120. msprobe/mindspore/cell_processor.py +204 -33
  121. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  122. msprobe/mindspore/common/const.py +73 -2
  123. msprobe/mindspore/common/utils.py +157 -29
  124. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  125. msprobe/mindspore/compare/distributed_compare.py +2 -26
  126. msprobe/mindspore/compare/ms_compare.py +18 -398
  127. msprobe/mindspore/compare/ms_graph_compare.py +20 -10
  128. msprobe/mindspore/compare/utils.py +37 -0
  129. msprobe/mindspore/debugger/debugger_config.py +59 -7
  130. msprobe/mindspore/debugger/precision_debugger.py +83 -90
  131. msprobe/mindspore/dump/cell_dump_process.py +902 -0
  132. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
  133. msprobe/mindspore/dump/dump_tool_factory.py +18 -8
  134. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  135. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  136. msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
  137. msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
  138. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  139. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  140. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
  141. msprobe/mindspore/dump/jit_dump.py +35 -27
  142. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  143. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  144. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
  145. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
  146. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  147. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
  148. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
  149. msprobe/mindspore/grad_probe/global_context.py +9 -2
  150. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
  151. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  152. msprobe/mindspore/grad_probe/hook.py +2 -4
  153. msprobe/mindspore/mindspore_service.py +111 -0
  154. msprobe/mindspore/monitor/common_func.py +52 -0
  155. msprobe/mindspore/monitor/data_writers.py +237 -0
  156. msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
  157. msprobe/mindspore/monitor/features.py +13 -1
  158. msprobe/mindspore/monitor/module_hook.py +568 -444
  159. msprobe/mindspore/monitor/optimizer_collect.py +331 -0
  160. msprobe/mindspore/monitor/utils.py +71 -9
  161. msprobe/mindspore/ms_config.py +16 -15
  162. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
  163. msprobe/mindspore/task_handler_factory.py +5 -2
  164. msprobe/msprobe.py +19 -0
  165. msprobe/nan_analyze/__init__.py +14 -0
  166. msprobe/nan_analyze/analyzer.py +255 -0
  167. msprobe/nan_analyze/graph.py +189 -0
  168. msprobe/nan_analyze/utils.py +211 -0
  169. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  170. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
  171. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  172. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
  173. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
  174. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
  175. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
  176. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
  177. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
  178. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
  179. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  180. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  181. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  182. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
  183. msprobe/pytorch/attl_manager.py +65 -0
  184. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
  185. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  186. msprobe/pytorch/common/utils.py +53 -19
  187. msprobe/pytorch/compare/distributed_compare.py +4 -36
  188. msprobe/pytorch/compare/pt_compare.py +13 -84
  189. msprobe/pytorch/compare/utils.py +47 -0
  190. msprobe/pytorch/debugger/debugger_config.py +34 -17
  191. msprobe/pytorch/debugger/precision_debugger.py +50 -96
  192. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  193. msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
  194. msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
  195. msprobe/pytorch/free_benchmark/common/utils.py +1 -1
  196. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
  201. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
  202. msprobe/pytorch/function_factory.py +1 -1
  203. msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
  204. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  205. msprobe/pytorch/hook_module/api_register.py +155 -0
  206. msprobe/pytorch/hook_module/hook_module.py +18 -22
  207. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  208. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  209. msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
  210. msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
  211. msprobe/pytorch/hook_module/utils.py +28 -2
  212. msprobe/pytorch/monitor/csv2tb.py +14 -4
  213. msprobe/pytorch/monitor/data_writers.py +259 -0
  214. msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
  215. msprobe/pytorch/monitor/module_hook.py +336 -241
  216. msprobe/pytorch/monitor/module_metric.py +17 -0
  217. msprobe/pytorch/monitor/optimizer_collect.py +244 -224
  218. msprobe/pytorch/monitor/utils.py +84 -4
  219. msprobe/pytorch/online_dispatch/compare.py +0 -2
  220. msprobe/pytorch/online_dispatch/dispatch.py +13 -2
  221. msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
  222. msprobe/pytorch/online_dispatch/utils.py +3 -0
  223. msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
  224. msprobe/pytorch/parse_tool/lib/utils.py +5 -4
  225. msprobe/pytorch/pt_config.py +16 -11
  226. msprobe/pytorch/pytorch_service.py +70 -0
  227. msprobe/visualization/builder/graph_builder.py +69 -10
  228. msprobe/visualization/builder/msprobe_adapter.py +24 -12
  229. msprobe/visualization/compare/graph_comparator.py +63 -51
  230. msprobe/visualization/compare/mode_adapter.py +22 -20
  231. msprobe/visualization/graph/base_node.py +11 -4
  232. msprobe/visualization/graph/distributed_analyzer.py +1 -10
  233. msprobe/visualization/graph/graph.py +2 -13
  234. msprobe/visualization/graph/node_op.py +1 -2
  235. msprobe/visualization/graph_service.py +251 -104
  236. msprobe/visualization/utils.py +26 -44
  237. msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
  238. msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
  239. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
  240. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  241. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  242. msprobe/mindspore/service.py +0 -543
  243. msprobe/pytorch/hook_module/api_registry.py +0 -166
  244. msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
  245. msprobe/pytorch/hook_module/wrap_functional.py +0 -66
  246. msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
  247. msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
  248. msprobe/pytorch/hook_module/wrap_torch.py +0 -84
  249. msprobe/pytorch/hook_module/wrap_vf.py +0 -60
  250. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  251. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  252. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  253. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  254. msprobe/pytorch/service.py +0 -470
  255. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
  256. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
  257. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
  258. {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
  259. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  260. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  261. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -16,6 +16,7 @@ import re
 
 import torch
 
+from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -143,6 +144,20 @@ class IdentMetric(Metric):
         return tensor
 
 
+@register_config_metric("shape")
+class ShapeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.shape
+
+
+@register_config_metric("dtype")
+class DtypeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.dtype
+
+
 def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
@@ -166,6 +181,8 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             # Non-tensor in/output filled with nan.
            out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
            continue
+        if is_float8_tensor(tensor):
+            tensor = tensor.float()
        for metric_name in ops:
            fun_metric = config_metric_registry.get(metric_name)
            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
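
Note: the hunks above appear to come from msprobe/pytorch/monitor/module_metric.py (entry 216 in the file list, +17 -0): float8 tensors are now upcast via is_float8_tensor before statistics are computed, and two metadata metrics, "shape" and "dtype", are registered alongside the existing ones. A minimal usage sketch under those assumptions follows; the import path and the tag string are inferred, not stated by the diff itself.

    # Hypothetical sketch; assumes mindstudio-probe 8.1.0 is installed and that
    # get_metrics populates the provided out_dict as shown in the hunk above.
    import torch
    from msprobe.pytorch.monitor.module_metric import get_metrics  # path inferred from the file list

    tag2tensor = {"layer1.weight/post_grad": torch.randn(4, 8, dtype=torch.bfloat16)}  # placeholder tag
    out = {tag: {} for tag in tag2tensor}
    get_metrics(["shape", "dtype"], tag2tensor, eps=1e-8, out_dict=out)
    # Expected: out["layer1.weight/post_grad"]["shape"] == torch.Size([4, 8])
    #           out["layer1.weight/post_grad"]["dtype"] == torch.bfloat16
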
@@ -12,129 +12,120 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from collections import defaultdict
+from abc import abstractmethod
 
 import torch
-import torch.distributed as dist
 
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
+from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.core.common.const import MonitorConst
 
 
 class OptimizerMon(object):
-    def __init__(self) -> None:
+    def __init__(self, torch_opt) -> None:
         self.fp16_to_fp32_param = {}
-        self.is_stage3 = False
+        self.torch_opt = torch_opt
+        self.state = {}
+
+    def narrow_from_flatten(self, param, flatten_state):
+        return flatten_state
+
+    def get_state(self, torch_opt):
+        if hasattr(torch_opt, 'chained_optimizers'):
+            for opt in torch_opt.chained_optimizers:
+                self._get_single_state(opt)
+        else:
+            self._get_single_state(torch_opt)
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        pass
+    def fetch_grad(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
 
-    def _fetch_mv_in_adam(self, monitor, torch_opt, params2name):
-        exp_avg_dict = defaultdict(float)
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
+        grad_dict = {}
+        first_param = True
         for param, name in params2name.items():
-            if param in self.fp16_to_fp32_param:
-                param = self.fp16_to_fp32_param[param]
-
-            if param in torch_opt.state:
-                state_param = torch_opt.state.get(param, None)
-                exp_avg = state_param.get("exp_avg", None)
-                exp_avg_sq = state_param.get("exp_avg_sq", None)
-                if exp_avg is None or exp_avg_sq is None:
-                    logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                    continue
+            if monitor.duplicate_param.get(name, False):
+                continue
+            if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                continue
+            grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+            if param.numel() != element_in_cur_partition:
+                if first_param:
+                    grad = grad.flatten()[-element_in_cur_partition:]
+                else:  # supposed to be the last one
+                    grad = grad.flatten()[:element_in_cur_partition]
+            first_param = False
+
+            if grad is None:
+                if not monitor.fsdp_wrapped_module:
+                    logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        return grad_dict
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        pass
+
+    def fetch_mv(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+        if not self.state:
+            self.get_state(self.torch_opt)
+
+        exp_avg_dict = {}
+        exp_avg_sq_dict = {}
+        update_dict = {}
+        ratio_dict = {}
+
+        if not self.state:
+            logger.warning('optimizer state can not accessed')
+            return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+        for lp_param, name in params2name.items():
+            if lp_param in self.fp16_to_fp32_param:
+                hp_param = self.fp16_to_fp32_param[lp_param]
+            else:
+                hp_param = lp_param
+
+            if hp_param in self.state:
+                state_param = self.state.get(hp_param, {})
+                exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
                 if monitor.mv_distribution:
                     exp_avg_dict[name] = exp_avg
                     exp_avg_sq_dict[name] = exp_avg_sq
                 if monitor.mg_direction:
                     exp_avg_dict[name] = exp_avg
                 if monitor.ur_distribution:
-                    if len(torch_opt.param_groups) > 1:
-                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                    if len(self.torch_opt.param_groups) > 1:
+                        logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.")
                     if 'step' in state_param:
                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                    elif 'step' in torch_opt.param_groups[0]:
-                        step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                    elif 'step' in self.torch_opt.param_groups[0]:
+                        step = self.torch_opt.param_groups[0]['step']  # AdamW from mindspeed
                     else:
                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
                         continue
-                    exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                    exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                    exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
                     ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
-    def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat):
-        exp_avg_dict = defaultdict(float)
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
-        param2name = defaultdict()
-        fp32_partitioned_groups_flat_grad = defaultdict()
-        partition_id = dist.get_rank()
-
-        def get_flatten_grad(self, optimizer, group_idx):
-            if fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
-                        optimizer.averaged_gradients[group_idx],
-                        int(optimizer.partition_size[group_idx])
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                else:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
-                        optimizer.averaged_gradients[group_idx]
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                return fp32_partitioned_groups_flat_grad
-            else:
-                return fp32_partitioned_groups_flat[group_idx].grad
-
-        for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
-
-        for name in params2name.values():
-            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
-            if group_with_rank != partition_id and isinstance(group_with_rank, int):
-                continue
-            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
-            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
-            param2name[fp32_param] = name
-            if not torch_opt.state:
-                continue
-            state_param = list(torch_opt.state.values())[group_idx]
-            exp_avg = state_param.get("exp_avg", None)
-            exp_avg_sq = state_param.get("exp_avg_sq", None)
-            if exp_avg is None or exp_avg_sq is None:
-                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                continue
-            exp_avg = exp_avg[start_idx: end_idx]
-            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
-            if monitor.mv_distribution:
-                exp_avg_dict[name] = exp_avg
-                exp_avg_sq_dict[name] = exp_avg_sq
-            if monitor.mg_direction:
-                exp_avg_dict[name] = exp_avg
-            if monitor.ur_distribution:
-                if 'step' in state_param:
-                    step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                elif 'step' in torch_opt.param_groups[group_idx]:
-                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
-                else:
-                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
-                    continue
-                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        del fp32_partitioned_groups_flat_grad
-        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
-                            grad=param2name)
+
+    def _get_single_state(self, torch_opt):
+        state = {}
+        if hasattr(torch_opt, 'param_to_cpu_states_map'):
+            state = torch_opt.param_to_cpu_states_map
+        elif hasattr(torch_opt, 'state'):
+            state = torch_opt.state
+        elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'):
+            state = torch_opt.optimizer.state
+        self.state.update(state)
 
 
 class MixPrecisionOptimizerMon(OptimizerMon):
@@ -142,21 +133,14 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     混合精度优化器监控类。在混合精度训练中监控和管理优化器。
     混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
     """
-
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         if not (hasattr(torch_opt, "model_float16_groups") and
                 hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
@@ -167,141 +151,176 @@ class MegatronDistributedOptimizerMon(OptimizerMon):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
+class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
-class MegatronFP32OptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
+class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
-class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            for opt in torch_opt.chained_optimizers:
-                self.map_fp16_tp_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+class DeepSpeedZeroOptimizerMon(OptimizerMon):
+    """
+    Base monitor class for DeepSpeed ZeRO optimizer.
+    ZeRO stage 0 no partition
+    ZeRO stage 1 partitions optimizer states across data parallel processes.
+    ZeRO stage 2 additionally partitions gradients.
+    ZeRO stage 3 additionally partitions parameters.
+
+    This class provides monitoring capabilities for ZeRO optimizers by:
+    - Handling gradient collection for different ZeRO stages
+    - Managing optimizer state access for monitoring
+    """
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = ''
+        self.bit16_groups = []
+        self.fp32_flat_groups = []
+        self.param2group = ()
+        self.param2index = []
+        self.group_offset = {}
+
+    @abstractmethod
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        raise NotImplementedError
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address is None
+
+    def get_position(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address.start, hp_address.numel
+
+    def get_group_index(self):
+        param2group = {}
+        for group_idx, bit16_group in enumerate(self.bit16_groups):
+            for param in bit16_group:
+                param2group[param] = group_idx
+        return param2group
+
+    def get_param_index(self, lp_param, group_idx):
+        if not self.param2index:
+            for group in self.bit16_groups:
+                param2index = {}
+                for index, param in enumerate(group):
+                    param2index[param] = index
+                self.param2index.append(param2index)
+
+        return self.param2index[group_idx][lp_param]
+
+    def narrow_from_flatten(self, param, flatten_state):
+        if flatten_state is None:
+            return flatten_state
+        group_idx = self.param2group[param]
+        if self.param_not_in_partition(param, group_idx):
+            return None
+        start, numel = self.get_position(param, group_idx)
+        return flatten_state.narrow(0, start, numel)
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for group_idx, group in enumerate(self.bit16_groups):
+            for param in group:
+                self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+    def fetch_grad(self, monitor, params2name):
+        grad_dict = {}
+        for lp_param, name in params2name.items():
+            group_idx = self.param2group[lp_param]
+            param_id = self.get_param_index(lp_param, group_idx)
+            if self.param_not_in_partition(lp_param, group_idx):
+                continue
+            if self.stage == '1or2':
+                param_id = param_id - self.group_offset[group_idx] - 1
+            grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        return grad_dict
+
+
+class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '0'
+        self.bit16_groups = torch_opt.bf16_groups
+        self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition
+        self.param2group = self.get_group_index()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '1or2'
+        self.bit16_groups = torch_opt.bit16_groups
+        self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups
+        self.param2group = self.get_group_index()
+        self.group_offset = {}
+        self.get_group_offset()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        if getattr(self.torch_opt, "cpu_offload", False):
+            grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad
+            start, numel = self.get_position(lp_param, group_idx)
+            grad = grads.narrow(0, start, numel)
+        else:
+            grad = self.torch_opt.averaged_gradients[group_idx][param_id]
+        return grad
+
+    def get_group_offset(self):
+        for group_idx, group in enumerate(self.bit16_groups):
+            self.group_offset[group_idx] = -1
+            for lp_param in group:
+                if self.param_not_in_partition(lp_param, group_idx):
+                    self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
+                else:
+                    break
 
 
-class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            for opt in torch_opt.chained_optimizers:
-                self.map_fp16_tp_fp32_param(opt)
+class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '3'
+        self.bit16_groups = torch_opt.fp16_groups
+        self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
+        self.param2group = self.get_group_index()
 
-        if not isinstance(torch_opt, torch.optim.Optimizer):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
-
-class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
-
-class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-    def get_param_index(self, params2name, name2index, torch_opt):
-        fp16_groups = torch_opt.fp16_partitioned_groups
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, fp16_group in enumerate(fp16_groups):
-            for param in fp16_group:
-                param_length = len(param.flatten())
-                index_length[idx] = (index, index + param_length, group_idx)
-                index += param_length
-                idx += 1
-        for _, name in params2name.items():
-            idx = name2index[name]
-            start_idx, end_idx, group_idx = index_length[idx]
-            name2indices[name] = (start_idx, end_idx, group_idx, None)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        self.is_stage3 = True
-        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-    @staticmethod
-    def get_group_index(fp32_length, world_size, index):
-        for i in range(len(fp32_length) - 1):
-            if fp32_length[i] <= index < fp32_length[i + 1]:
-                interval_start = fp32_length[i]
-                interval_length = fp32_length[i + 1] - fp32_length[i]
-                sub_interval_length = interval_length // world_size
-                sub_index = (index - interval_start) // sub_interval_length
-                sub_interval_start = interval_start + sub_index * sub_interval_length
-                return sub_interval_start, min(sub_index, world_size - 1)
-        return fp32_length[-1], 0
-
-    def get_param_index(self, params2name, name2index, torch_opt):
-        padding = torch_opt.groups_padding
-        world_size = dist.get_world_size()
-        fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
-            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
-
-        bf16_groups = []
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
-            bf16_groups.extend(bf16_group)
-            for param in bf16_group:
-                param_length = len(param.flatten())
-                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
-                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
-                index += param_length
-                idx += 1
-        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
-        for _, name in params2name.items():
-            name_index = name2index[name]
-            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
-            need_padding = True if group_with_rank == world_size - 1 else False
-            new_start_idx = start_idx - group_index
-            new_end_idx = end_idx - group_index
-            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
-                    group_length - 1) == 0:
-                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
-            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DummyOptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+    def param_not_in_partition(self, param, group_index):
+        """Each param partioned across all zero ranks"""
+        return False
+
+    def get_position(self, lp_param, group_idx):
+        param_id = self.torch_opt.get_param_id(lp_param)
+        return self.torch_opt.grad_position[param_id][1:]
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.averaged_gradients[group_idx][param_id]
 
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "FP32Optimizer": MegatronFP32OptimizerMon,
+        "FP32Optimizer": OptimizerMon,
         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
        "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
        "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
        "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
        "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "Adam": DummyOptimizerMon
+        "Adam": OptimizerMon
     }
 
     @staticmethod
@@ -310,6 +329,7 @@ class OptimizerMonFactory:
         optimizer_class = optimizer.__class__.__name__
         if optimizer_class == "ChainedOptimizer":
             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+        logger.info(f'The optimizer type is {optimizer_class}')
 
-        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
-        return optimizer_mon_class(), optimizer_class
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+        return optimizer_mon_class(optimizer)
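
Note: the remaining hunks appear to come from msprobe/pytorch/monitor/optimizer_collect.py (entry 217 in the file list, +244 -224). The per-optimizer fetch_mv(monitor, torch_opt, params2name) overrides and the DeepSpeed name2indices bookkeeping are replaced by a single OptimizerMon base class that holds the wrapped optimizer and exposes fetch_mv/fetch_grad taking only (monitor, params2name), while the factory now returns a constructed monitor object instead of an (instance, class name) pair. A minimal calling sketch under those assumptions; the factory method name create_optimizer_mon is not shown in the hunk above and is assumed, and the SimpleNamespace below only stands in for the real monitor object.

    # Hypothetical sketch; assumes mindstudio-probe 8.1.0 and the assumed factory method name.
    from types import SimpleNamespace

    import torch
    from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory  # path inferred from the file list

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(1, 4)).sum().backward()
    opt.step()  # populates exp_avg / exp_avg_sq in opt.state

    params2name = {p: n for n, p in model.named_parameters()}
    # Only the flags read by OptimizerMon.fetch_mv in the hunk above are provided here.
    monitor = SimpleNamespace(mv_distribution=True, mg_direction=False, ur_distribution=False)

    optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt)  # "Adam" maps to the OptimizerMon base class
    mv = optimizer_mon.fetch_mv(monitor, params2name)              # MVResult(exp_avg, exp_avg_sq, update, ratio)
    print({name: tuple(t.shape) for name, t in mv.exp_avg.items()})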