mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (278)
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
  2. mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
  3. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/README.md +84 -18
  6. msprobe/__init__.py +16 -1
  7. msprobe/config.json +1 -5
  8. msprobe/core/advisor/advisor.py +16 -11
  9. msprobe/core/advisor/advisor_const.py +6 -7
  10. msprobe/core/advisor/advisor_result.py +12 -12
  11. msprobe/core/common/const.py +164 -3
  12. msprobe/core/common/exceptions.py +26 -4
  13. msprobe/core/common/file_utils.py +196 -27
  14. msprobe/core/common/inplace_op_checker.py +53 -0
  15. msprobe/core/common/inplace_ops.yaml +251 -0
  16. msprobe/core/common/log.py +46 -18
  17. msprobe/core/common/utils.py +308 -209
  18. msprobe/core/common_config.py +60 -38
  19. msprobe/core/compare/acc_compare.py +332 -94
  20. msprobe/core/compare/check.py +104 -22
  21. msprobe/core/compare/compare_cli.py +42 -5
  22. msprobe/core/compare/highlight.py +162 -57
  23. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  24. msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
  25. msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
  26. msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
  27. msprobe/core/compare/multiprocessing_compute.py +33 -8
  28. msprobe/core/compare/npy_compare.py +73 -29
  29. msprobe/core/compare/utils.py +306 -247
  30. msprobe/core/data_dump/data_collector.py +44 -43
  31. msprobe/core/data_dump/data_processor/base.py +88 -35
  32. msprobe/core/data_dump/data_processor/factory.py +20 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +14 -8
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +180 -66
  35. msprobe/core/data_dump/json_writer.py +63 -42
  36. msprobe/core/data_dump/scope.py +143 -48
  37. msprobe/core/grad_probe/constant.py +31 -13
  38. msprobe/core/grad_probe/grad_compare.py +20 -4
  39. msprobe/core/grad_probe/utils.py +44 -3
  40. msprobe/core/overflow_check/abnormal_scene.py +185 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +29 -9
  48. msprobe/docs/02.config_introduction.md +83 -84
  49. msprobe/docs/03.config_examples.md +3 -20
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +143 -13
  52. msprobe/docs/06.data_dump_MindSpore.md +197 -88
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +69 -46
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +52 -17
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +51 -15
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +187 -99
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +253 -31
  58. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +60 -55
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +159 -0
  62. msprobe/docs/17.grad_probe.md +19 -22
  63. msprobe/docs/18.online_dispatch.md +89 -0
  64. msprobe/docs/19.monitor.md +468 -0
  65. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  66. msprobe/docs/21.visualization_PyTorch.md +386 -0
  67. msprobe/docs/22.visualization_MindSpore.md +384 -0
  68. msprobe/docs/23.tool_function_introduction.md +28 -0
  69. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +25 -10
  70. msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
  71. msprobe/docs/img/compare_result.png +0 -0
  72. msprobe/docs/img/monitor/cpu_info.png +0 -0
  73. msprobe/docs/img/ms_dump.png +0 -0
  74. msprobe/docs/img/ms_layer.png +0 -0
  75. msprobe/docs/img/pt_dump.png +0 -0
  76. msprobe/mindspore/__init__.py +16 -0
  77. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +130 -138
  78. msprobe/mindspore/api_accuracy_checker/api_info.py +27 -5
  79. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  80. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  81. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  82. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
  83. msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
  84. msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
  85. msprobe/mindspore/api_accuracy_checker/main.py +27 -3
  86. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
  87. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
  88. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  89. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  90. msprobe/mindspore/cell_processor.py +58 -13
  91. msprobe/mindspore/common/const.py +35 -13
  92. msprobe/mindspore/common/log.py +5 -9
  93. msprobe/mindspore/common/utils.py +60 -5
  94. msprobe/mindspore/compare/distributed_compare.py +15 -28
  95. msprobe/mindspore/compare/ms_compare.py +319 -158
  96. msprobe/mindspore/compare/ms_graph_compare.py +99 -49
  97. msprobe/mindspore/debugger/debugger_config.py +20 -14
  98. msprobe/mindspore/debugger/precision_debugger.py +43 -13
  99. msprobe/mindspore/dump/dump_tool_factory.py +18 -1
  100. msprobe/mindspore/dump/hook_cell/api_registry.py +23 -3
  101. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +203 -0
  102. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +107 -10
  103. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  104. msprobe/mindspore/dump/jit_dump.py +56 -20
  105. msprobe/mindspore/dump/kernel_graph_dump.py +19 -5
  106. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -6
  107. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  108. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  109. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +162 -41
  110. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  111. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -1
  112. msprobe/mindspore/free_benchmark/common/utils.py +37 -8
  113. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  114. msprobe/mindspore/free_benchmark/handler/base_handler.py +20 -5
  115. msprobe/mindspore/free_benchmark/handler/check_handler.py +21 -7
  116. msprobe/mindspore/free_benchmark/handler/fix_handler.py +18 -3
  117. msprobe/mindspore/free_benchmark/handler/handler_factory.py +21 -6
  118. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +23 -8
  119. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +29 -5
  120. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +25 -10
  121. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +45 -19
  122. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +29 -8
  123. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  124. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +22 -7
  125. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +17 -2
  126. msprobe/mindspore/grad_probe/global_context.py +44 -14
  127. msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
  128. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  129. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  130. msprobe/mindspore/grad_probe/hook.py +24 -10
  131. msprobe/mindspore/grad_probe/utils.py +18 -5
  132. msprobe/mindspore/ms_config.py +22 -15
  133. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -6
  134. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  135. msprobe/mindspore/runtime.py +15 -0
  136. msprobe/mindspore/service.py +75 -150
  137. msprobe/mindspore/task_handler_factory.py +15 -0
  138. msprobe/msprobe.py +24 -7
  139. msprobe/pytorch/__init__.py +23 -3
  140. msprobe/pytorch/api_accuracy_checker/common/config.py +81 -2
  141. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  142. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  143. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +50 -25
  144. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  145. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  146. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  147. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  148. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  149. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
  150. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  151. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +73 -33
  152. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +44 -18
  153. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +32 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +122 -172
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +158 -4
  156. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +30 -24
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +68 -31
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +115 -0
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +26 -9
  161. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  162. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  163. msprobe/pytorch/bench_functions/__init__.py +18 -3
  164. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  165. msprobe/pytorch/bench_functions/confusion_transpose.py +20 -1
  166. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  167. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  168. msprobe/pytorch/bench_functions/linear.py +15 -0
  169. msprobe/pytorch/bench_functions/matmul_backward.py +33 -6
  170. msprobe/pytorch/bench_functions/npu_fusion_attention.py +280 -157
  171. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  172. msprobe/pytorch/bench_functions/rotary_mul.py +32 -9
  173. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  174. msprobe/pytorch/bench_functions/swiglu.py +29 -6
  175. msprobe/pytorch/common/__init__.py +15 -0
  176. msprobe/pytorch/common/log.py +18 -6
  177. msprobe/pytorch/common/parse_json.py +31 -16
  178. msprobe/pytorch/common/utils.py +96 -40
  179. msprobe/pytorch/compare/distributed_compare.py +13 -14
  180. msprobe/pytorch/compare/match.py +15 -0
  181. msprobe/pytorch/compare/pt_compare.py +44 -10
  182. msprobe/pytorch/debugger/debugger_config.py +69 -52
  183. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  184. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  185. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  186. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  187. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  188. msprobe/pytorch/free_benchmark/common/enums.py +43 -0
  189. msprobe/pytorch/free_benchmark/common/params.py +23 -1
  190. msprobe/pytorch/free_benchmark/common/utils.py +43 -5
  191. msprobe/pytorch/free_benchmark/compare/grad_saver.py +47 -9
  192. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +17 -0
  193. msprobe/pytorch/free_benchmark/main.py +19 -4
  194. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  195. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  196. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +18 -1
  197. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +21 -4
  198. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +28 -2
  199. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +19 -0
  200. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  201. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  202. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  203. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +65 -16
  204. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  205. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +21 -5
  206. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  207. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  208. msprobe/pytorch/function_factory.py +17 -2
  209. msprobe/pytorch/functional/module_dump.py +84 -0
  210. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  211. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  212. msprobe/pytorch/hook_module/__init__.py +16 -1
  213. msprobe/pytorch/hook_module/api_registry.py +13 -8
  214. msprobe/pytorch/hook_module/hook_module.py +17 -19
  215. msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
  216. msprobe/pytorch/hook_module/utils.py +4 -6
  217. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  218. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  219. msprobe/pytorch/hook_module/wrap_functional.py +21 -20
  220. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  221. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  222. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  223. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  224. msprobe/pytorch/module_processer.py +18 -6
  225. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  226. msprobe/pytorch/monitor/anomaly_detect.py +340 -0
  227. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  228. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  229. msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
  230. msprobe/pytorch/monitor/features.py +108 -0
  231. msprobe/pytorch/monitor/module_hook.py +870 -0
  232. msprobe/pytorch/monitor/module_metric.py +193 -0
  233. msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
  234. msprobe/pytorch/monitor/optimizer_collect.py +295 -0
  235. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  236. msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
  237. msprobe/pytorch/monitor/utils.py +250 -0
  238. msprobe/pytorch/monitor/visualizer.py +59 -0
  239. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  240. msprobe/pytorch/online_dispatch/compare.py +38 -48
  241. msprobe/pytorch/online_dispatch/dispatch.py +50 -25
  242. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  243. msprobe/pytorch/online_dispatch/single_compare.py +60 -39
  244. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +9 -1
  245. msprobe/pytorch/online_dispatch/utils.py +48 -23
  246. msprobe/pytorch/parse.py +15 -0
  247. msprobe/pytorch/parse_tool/cli.py +5 -6
  248. msprobe/pytorch/parse_tool/lib/compare.py +19 -26
  249. msprobe/pytorch/parse_tool/lib/config.py +1 -1
  250. msprobe/pytorch/parse_tool/lib/parse_tool.py +4 -2
  251. msprobe/pytorch/parse_tool/lib/utils.py +40 -55
  252. msprobe/pytorch/parse_tool/lib/visualization.py +3 -1
  253. msprobe/pytorch/pt_config.py +192 -40
  254. msprobe/pytorch/service.py +110 -35
  255. msprobe/visualization/__init__.py +14 -0
  256. msprobe/visualization/builder/__init__.py +14 -0
  257. msprobe/visualization/builder/graph_builder.py +165 -0
  258. msprobe/visualization/builder/msprobe_adapter.py +205 -0
  259. msprobe/visualization/compare/__init__.py +14 -0
  260. msprobe/visualization/compare/graph_comparator.py +130 -0
  261. msprobe/visualization/compare/mode_adapter.py +211 -0
  262. msprobe/visualization/graph/__init__.py +14 -0
  263. msprobe/visualization/graph/base_node.py +124 -0
  264. msprobe/visualization/graph/graph.py +200 -0
  265. msprobe/visualization/graph/node_colors.py +95 -0
  266. msprobe/visualization/graph/node_op.py +39 -0
  267. msprobe/visualization/graph_service.py +214 -0
  268. msprobe/visualization/utils.py +232 -0
  269. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  270. msprobe/docs/04.acl_config_examples.md +0 -76
  271. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -43
  272. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -107
  273. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  274. msprobe/pytorch/functional/dump_module.py +0 -39
  275. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
  276. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
  277. /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
  278. /msprobe/pytorch/{functional/data_processor.py → monitor/distributed/__init__.py} +0 -0
msprobe/pytorch/monitor/module_metric.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import itertools
+import math
+import re
+import statistics
+
+import torch
+
+from msprobe.core.common.const import MonitorConst
+from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+from msprobe.core.common.log import logger
+
+
+def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
+    if rank is None:
+        return f"{module_or_param_name}/{tag}"
+    else:
+        return f"{module_or_param_name}/rank{rank}/{tag}"
+
+
+def squash_param_name(param_name):
+    name = ''
+    for pattern in [r'layers?\.(.*)', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
+        match = re.findall(pattern, param_name)
+        if match:
+            name += match[0]
+            break
+    if name == '':
+        name = param_name
+    return name
+
+
+# Registry that stores every Metric implementation class
+config_metric_registry = {}
+
+
+def register_config_metric(key, cls=None):
+    """Decorator that registers a Metric implementation class."""
+    if cls is None:
+        # Called without a class: return the decorator function
+        return lambda cls_: register_config_metric(key, cls_)
+    config_metric_registry[key] = cls()
+    return cls
+
+
+class TensorMetrics:
+    fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean}
+
+    def __init__(self) -> None:
+        self.metrics = {}  # tensor_tag --> []
+        self.cur_idx = {}
+
+    def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8):
+        """get stats and insert into metrics dictionary"""
+        prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
+        for stat_op in stat_ops:
+            y = TensorMetrics.fun_map[stat_op](tensor)
+            key = f"{prefix}_{stat_op}"
+            if key not in self.metrics:
+                self.metrics[key] = []
+                self.cur_idx[key] = 0
+            self.metrics[key].append(y)
+
+    def flush(self, tb_writer):
+        for key, metric_list in self.metrics.items():
+            start = self.cur_idx[key]
+            for v in metric_list[start:]:
+                tb_writer.add_scalar(key, v.item(), global_step=self.cur_idx[key])
+                self.cur_idx[key] += 1
+
+
+class Metric(object):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        raise NotImplementedError
+
+    def get_metric(self, tensor, eps):
+        try:
+            return self.get_metric_value(tensor, eps)
+        except RuntimeError:
+            return torch.tensor(torch.nan).to(tensor.device)
+
+
+@register_config_metric("min")
+class MinMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_min(tensor)
+
+
+@register_config_metric("mean")
+class MeanMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_mean(tensor)
+
+
+@register_config_metric("max")
+class MaxMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_max(tensor)
+
+
+@register_config_metric("norm")
+class NormMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_norm(tensor)
+
+
+@register_config_metric("zeros")
+class ZerosMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_zeros(tensor, eps)
+
+
+@register_config_metric("nans")
+class NaNsMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return get_nans(tensor)
+
+
+@register_config_metric("id")
+class IdentMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        if tensor.dim() != 0:
+            return None
+        return tensor
+
+
+def get_metrics(ops, tag2tensor, eps, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        for metric_name in ops:
+            fun_metric = config_metric_registry.get(metric_name)
+            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
+    return out_dict
+
+
+def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
+    if not metric_value:
+        return
+    tensors = []
+    tags = list(itertools.product(metric_value.keys(), ops))
+    for op2tensor in metric_value.values():
+        tensors.extend(op2tensor.values())
+    with torch.no_grad():
+        metric_list = torch.stack(tensors).cpu()
+        for tag, metric in zip(tags, metric_list):
+            summary_writer.add_scalar(tag, metric, step)
+
+
+def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
+    write_metrics_base(ops, summary_writer, metric_value, step, prefix='')
+
+    if not summary_writer.header:
+        # Forward norms use input.ops_ and output.ops_; backward ones use input_grad.ops_ and output_grad.ops_
+        if prefix in {"actv", "actv_grad"}:
+            if prefix == "actv":
+                input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
+            else:
+                input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
+            ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
+            summary_writer.header = ["module_name", "step", *ops_]
+        else:
+            summary_writer.header = ["param_name", "step", *ops]
+
+        for key in metric_value.keys():
+            if MonitorConst.VPP_SEP in key:
+                summary_writer.header.insert(0, 'vpp_stage')
+                break
+    summary_writer.write_csv(prefix, step)
+    summary_writer.header = []
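
The registry above maps the op names from the monitor config ("min", "max", "mean", "norm", "zeros", "nans", "id") to singleton Metric objects, and get_metrics() evaluates every requested op for every tagged tensor in one pass. A minimal usage sketch, assuming the features helpers behave as their names suggest; the tag names and eps value here are illustrative, not taken from the package:

    import torch
    from msprobe.pytorch.monitor.module_metric import get_metrics

    # Hypothetical activation tensors, tagged the way the monitor tags module I/O.
    tag2tensor = {
        "decoder.0/rank0/input": torch.randn(4, 8),
        "decoder.0/rank0/output": torch.randn(4, 8),
    }
    # eps only matters for ratio-style metrics such as "zeros".
    stats = get_metrics(["min", "max", "norm"], tag2tensor, eps=1e-8)
    print(stats["decoder.0/rank0/input"]["norm"])  # a 0-dim torch.Tensor
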
msprobe/pytorch/monitor/module_spec_verifier.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import abc
+import torch
+
+from msprobe.core.common.log import logger
+
+# Registry that stores every validator implementation class
+config_validator_registry = {}
+
+
+def register_config_validator(cls):
+    """Decorator that registers a ConfigValidator implementation class."""
+    config_validator_registry[cls.__name__] = cls
+    return cls
+
+
+class ConfigValidator(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def check_pattern_match(self, config_spec: str):
+        pass
+
+    @abc.abstractmethod
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        pass
+
+
+@register_config_validator
+class TensorValidator(ConfigValidator):
+    def check_pattern_match(self, config_spec: str):
+        pattern = re.compile(r"tensor")
+        return pattern.match(config_spec)
+
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        if not torch.is_tensor(actual_data):
+            raise ValueError(
+                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
+
+
+@register_config_validator
+class TupleValidator(ConfigValidator):
+    def check_pattern_match(self, config_spec: str):
+        pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
+        return pattern.match(config_spec)
+
+    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+        length, index = pattern_match.groups()
+        if index is None:
+            index = 0
+        length, index = int(length), int(index)
+
+        if not (0 <= index < length):
+            raise ValueError(
+                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
+                f"y must be greater than or equal to 0 and less than x.")
+        if not isinstance(actual_data, tuple):
+            raise ValueError(
+                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
+        if len(actual_data) != length:
+            raise ValueError(
+                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
+                f"actual is {len(actual_data)}, please check.")
+        return index
+
+
+def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
+    focused_col = None
+    for _, validator_cls in config_validator_registry.items():
+        config_validator = validator_cls()
+        pattern_match = config_validator.check_pattern_match(config_spec)
+        if pattern_match:
+            try:
+                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
+            except ValueError as e:
+                logger.warning(f"config spec validate failed: {str(e)}")
+            return focused_col
+    logger.warning(f"config spec in {module_name} {data_type} not supported, "
+                   fr"expected spec: 'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
+    return focused_col
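
validate_config_spec() is the entry point that drives these validators: it tries each registered pattern against the config string and returns the focused element index for tuple specs, or None for plain tensor specs. A hedged sketch, with made-up module and data-type names:

    import torch
    from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec

    data = (torch.zeros(2), torch.ones(2))
    # 'tuple[2]:1' asserts a 2-tuple and focuses element 1.
    col = validate_config_spec("tuple[2]:1", data, "decoder.mlp", "input")       # -> 1
    # 'tensor' only type-checks; there is no column to focus, so None comes back.
    col = validate_config_spec("tensor", torch.zeros(2), "decoder.mlp", "output")  # -> None
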
msprobe/pytorch/monitor/optimizer_collect.py
@@ -0,0 +1,295 @@
+# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from collections import defaultdict
+
+import torch
+import torch.distributed as dist
+
+from msprobe.core.common.log import logger
+from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
+
+
+class OptimizerMon(object):
+    wrapped_optimizer = None
+
+    def __init__(self) -> None:
+        self.fp16_to_fp32_param = {}
+        self.is_stage3 = False
+
+    @classmethod
+    def set_wrapped_optimizer(cls, wrapped_optimizer):
+        cls.wrapped_optimizer = wrapped_optimizer
+
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        pass
+
+    def _fetch_mv_in_adam(self, monitor, torch_opt, params2name):
+        exp_avg_dict = defaultdict(float)
+        exp_avg_sq_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
+        for param, name in params2name.items():
+            if param in self.fp16_to_fp32_param:
+                param = self.fp16_to_fp32_param[param]
+
+            if param in torch_opt.state:
+                state_param = torch_opt.state.get(param, None)
+                exp_avg = state_param.get("exp_avg", None)
+                exp_avg_sq = state_param.get("exp_avg_sq", None)
+                if exp_avg is None or exp_avg_sq is None:
+                    logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
+                    continue
+                if monitor.mv_distribution:
+                    exp_avg_dict[name] = exp_avg
+                    exp_avg_sq_dict[name] = exp_avg_sq
+                if monitor.mg_direction:
+                    exp_avg_dict[name] = exp_avg
+                if monitor.ur_distribution:
+                    if len(torch_opt.param_groups) > 1:
+                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                    if 'step' in state_param:
+                        step = state_param['step']  # Optimizer from PyTorch, or FusedAdam from apex (used by Megatron)
+                    elif 'step' in torch_opt.param_groups[0]:
+                        step = torch_opt.param_groups[0]['step']  # AdamW from MindSpeed
+                    else:
+                        logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                        continue
+                    exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                    ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                    monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                    monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+    def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat):
+        exp_avg_dict = defaultdict(float)
+        exp_avg_sq_dict = defaultdict(float)
+        update_dict = defaultdict()
+        ratio_dict = defaultdict()
+        param2name = defaultdict()
+        fp32_partitioned_groups_flat_grad = defaultdict()
+        mix_prec_opt = OptimizerMon.wrapped_optimizer
+        partition_id = dist.get_rank()
+
+        def get_flatten_grad(self, optimizer, group_idx):
+            if fp32_partitioned_groups_flat[group_idx].grad is None:
+                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
+                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
+                        optimizer.averaged_gradients[group_idx],
+                        int(optimizer.partition_size[group_idx])
+                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                else:
+                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
+                        optimizer.averaged_gradients[group_idx]
+                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                return fp32_partitioned_groups_flat_grad
+            else:
+                return fp32_partitioned_groups_flat[group_idx].grad
+
+        for group_idx in range(len(fp32_partitioned_groups_flat)):
+            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
+
+        for name in params2name.values():
+            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
+            if group_with_rank != partition_id and isinstance(group_with_rank, int):
+                continue
+            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
+            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
+            param2name[fp32_param] = name
+            if not mix_prec_opt.state:
+                continue
+            state_param = list(mix_prec_opt.state.values())[group_idx]
+            exp_avg = state_param.get("exp_avg", None)
+            exp_avg_sq = state_param.get("exp_avg_sq", None)
+            if exp_avg is None or exp_avg_sq is None:
+                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
+                continue
+            exp_avg = exp_avg[start_idx: end_idx]
+            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
+            if monitor.mv_distribution:
+                exp_avg_dict[name] = exp_avg
+                exp_avg_sq_dict[name] = exp_avg_sq
+            if monitor.mg_direction:
+                exp_avg_dict[name] = exp_avg
+            if monitor.ur_distribution:
+                if 'step' in state_param:
+                    step = state_param['step']  # Optimizer from PyTorch, or FusedAdam from apex (used by Megatron)
+                elif 'step' in torch_opt.param_groups[group_idx]:
+                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from MindSpeed
+                else:
+                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                    continue
+                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+        del fp32_partitioned_groups_flat_grad
+        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
+                            grad=param2name)
+
+
+class MixPrecisionOptimizerMon(OptimizerMon):
+    """
+    Monitor for mixed-precision optimizers: tracks and manages the optimizer during
+    mixed-precision training, which speeds up training and reduces memory consumption
+    by lowering the precision of selected computations.
+    """
+
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        mix_prec_opt = self.wrapped_optimizer
+
+        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+            for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+                for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                    self.fp16_to_fp32_param[fp16_param] = fp32_param
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+class MegatronDistributedOptimizerMon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        mix_prec_opt = self.wrapped_optimizer
+        if not (hasattr(mix_prec_opt, "model_float16_groups") and
+                hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+            raise Exception(
+                "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
+                "if not, please check megatron-lm version")
+        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+            for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
+                                                    mix_prec_opt.shard_fp32_from_float16_groups):
+                for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                    self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+class MegatronFP32OptimizerMon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
+    def get_param_index(self, params2name, name2index):
+        mix_prec_opt = OptimizerMon.wrapped_optimizer
+        fp16_groups = mix_prec_opt.fp16_partitioned_groups
+        name2indices = defaultdict()
+        index_length = defaultdict()
+        index = 0
+        idx = 0
+        for group_idx, fp16_group in enumerate(fp16_groups):
+            for param in fp16_group:
+                param_length = len(param.flatten())
+                index_length[idx] = (index, index + param_length, group_idx)
+                index += param_length
+                idx += 1
+        for _, name in params2name.items():
+            idx = name2index[name]
+            start_idx, end_idx, group_idx = index_length[idx]
+            name2indices[name] = (start_idx, end_idx, group_idx, None)
+        return name2indices
+
+    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+        self.is_stage3 = True
+        mix_prec_opt = OptimizerMon.wrapped_optimizer
+        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
+
+    @staticmethod
+    def get_group_index(fp32_length, world_size, index):
+        for i in range(len(fp32_length) - 1):
+            if fp32_length[i] <= index < fp32_length[i + 1]:
+                interval_start = fp32_length[i]
+                interval_length = fp32_length[i + 1] - fp32_length[i]
+                sub_interval_length = interval_length // world_size
+                sub_index = (index - interval_start) // sub_interval_length
+                sub_interval_start = interval_start + sub_index * sub_interval_length
+                return sub_interval_start, min(sub_index, world_size - 1)
+        return fp32_length[-1], 0
+
+    def get_param_index(self, params2name, name2index):
+        mix_prec_opt = OptimizerMon.wrapped_optimizer
+        padding = mix_prec_opt.groups_padding
+        world_size = dist.get_world_size()
+        fp32_length = [0]
+        for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
+
+        bf16_groups = []
+        name2indices = defaultdict()
+        index_length = defaultdict()
+        index = 0
+        idx = 0
+        for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+            bf16_groups.extend(bf16_group)
+            for param in bf16_group:
+                param_length = len(param.flatten())
+                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
+                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
+                index += param_length
+                idx += 1
+        group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+        for _, name in params2name.items():
+            name_index = name2index[name]
+            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
+            need_padding = group_with_rank == world_size - 1
+            new_start_idx = start_idx - group_index
+            new_end_idx = end_idx - group_index
+            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
+                    group_length - 1) == 0:
+                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
+            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
+        return name2indices
+
+    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+        mix_prec_opt = OptimizerMon.wrapped_optimizer
+        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
+
+
+class DummyOptimizerMon(OptimizerMon):
+    def fetch_mv(self, monitor, torch_opt, params2name):
+        return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None)
+
+
+class OptimizerMonFactory:
+    _optimizer_mon_map = {
+        "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+        "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
+        "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
+        "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
+        "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
+        "unknown": DummyOptimizerMon
+    }

+    @staticmethod
+    def create_optimizer_mon(opt_ty: str):
+        if not opt_ty:
+            return DummyOptimizerMon()
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
+        if not optimizer_mon_class:
+            raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
+        return optimizer_mon_class()
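
OptimizerMonFactory maps the wrapped optimizer's type string to the matching monitor class. A short sketch of the intended call sequence, assuming the training framework supplies the wrapped optimizer and the monitor/params2name objects (those names are placeholders here):

    from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, OptimizerMonFactory

    # The monitor registers the mixed-precision wrapper once, up front.
    # OptimizerMon.set_wrapped_optimizer(wrapped_optimizer)  # supplied by the training framework
    mon = OptimizerMonFactory.create_optimizer_mon("Megatron_Float16OptimizerWithFloat16Params")
    # mv_result = mon.fetch_mv(monitor, torch_opt, params2name)  # returns an MVResult
    # An empty or falsy type string falls back to DummyOptimizerMon:
    fallback = OptimizerMonFactory.create_optimizer_mon("")
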