mindstudio-probe 1.1.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (299)
  1. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/METADATA +7 -6
  2. mindstudio_probe-1.2.1.dist-info/RECORD +396 -0
  3. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/WHEEL +1 -1
  4. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/entry_points.txt +0 -1
  5. msprobe/CMakeLists.txt +5 -0
  6. msprobe/README.md +51 -20
  7. msprobe/config.json +2 -3
  8. msprobe/core/advisor/advisor.py +8 -3
  9. msprobe/core/common/const.py +264 -15
  10. msprobe/core/common/exceptions.py +27 -3
  11. msprobe/core/common/file_utils.py +176 -26
  12. msprobe/core/common/inplace_op_checker.py +15 -0
  13. msprobe/core/common/inplace_ops.yaml +3 -0
  14. msprobe/core/common/log.py +27 -9
  15. msprobe/core/common/utils.py +204 -77
  16. msprobe/core/common_config.py +49 -14
  17. msprobe/core/compare/acc_compare.py +274 -198
  18. msprobe/core/compare/check.py +32 -33
  19. msprobe/core/compare/compare_cli.py +32 -14
  20. msprobe/core/compare/highlight.py +283 -127
  21. msprobe/core/compare/layer_mapping/__init__.py +19 -0
  22. msprobe/core/compare/layer_mapping/data_scope_parser.py +246 -0
  23. msprobe/core/compare/layer_mapping/layer_mapping.py +249 -0
  24. msprobe/core/compare/layer_mapping/postprocess_pass.py +95 -0
  25. msprobe/core/compare/merge_result/merge_result.py +380 -0
  26. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  27. msprobe/core/compare/multiprocessing_compute.py +2 -2
  28. msprobe/core/compare/npy_compare.py +135 -144
  29. msprobe/core/compare/utils.py +419 -274
  30. msprobe/core/data_dump/data_collector.py +60 -28
  31. msprobe/core/data_dump/data_processor/base.py +84 -36
  32. msprobe/core/data_dump/data_processor/factory.py +5 -3
  33. msprobe/core/data_dump/data_processor/mindspore_processor.py +152 -18
  34. msprobe/core/data_dump/data_processor/pytorch_processor.py +267 -110
  35. msprobe/core/data_dump/json_writer.py +29 -1
  36. msprobe/core/data_dump/scope.py +119 -39
  37. msprobe/core/grad_probe/constant.py +27 -13
  38. msprobe/core/grad_probe/grad_compare.py +18 -1
  39. msprobe/core/grad_probe/utils.py +30 -2
  40. msprobe/core/overflow_check/abnormal_scene.py +189 -0
  41. msprobe/core/overflow_check/api_info.py +55 -0
  42. msprobe/core/overflow_check/checker.py +138 -0
  43. msprobe/core/overflow_check/filter.py +157 -0
  44. msprobe/core/overflow_check/ignore_rules.yaml +55 -0
  45. msprobe/core/overflow_check/level.py +22 -0
  46. msprobe/core/overflow_check/utils.py +28 -0
  47. msprobe/docs/01.installation.md +96 -7
  48. msprobe/docs/02.config_introduction.md +50 -23
  49. msprobe/docs/03.config_examples.md +2 -9
  50. msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
  51. msprobe/docs/05.data_dump_PyTorch.md +93 -61
  52. msprobe/docs/06.data_dump_MindSpore.md +200 -95
  53. msprobe/docs/07.accuracy_checker_PyTorch.md +28 -28
  54. msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
  55. msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
  56. msprobe/docs/10.accuracy_compare_PyTorch.md +114 -50
  57. msprobe/docs/11.accuracy_compare_MindSpore.md +340 -48
  58. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  59. msprobe/docs/13.overflow_check_MindSpore.md +6 -6
  60. msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
  61. msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
  62. msprobe/docs/17.grad_probe.md +5 -6
  63. msprobe/docs/19.monitor.md +561 -0
  64. msprobe/docs/20.monitor_performance_baseline.md +52 -0
  65. msprobe/docs/21.visualization_PyTorch.md +466 -0
  66. msprobe/docs/22.visualization_MindSpore.md +481 -0
  67. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  68. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  69. msprobe/docs/25.tool_function_introduction.md +29 -0
  70. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  71. msprobe/docs/27.dump_json_instruction.md +521 -0
  72. msprobe/docs/FAQ.md +29 -2
  73. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  74. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  75. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +211 -0
  76. msprobe/docs/img/compare_result.png +0 -0
  77. msprobe/docs/img/merge_result.png +0 -0
  78. msprobe/docs/img/monitor/cpu_info.png +0 -0
  79. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  80. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  81. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  82. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  83. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  84. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  85. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  86. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  87. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  88. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  89. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  90. msprobe/docs/visualization/GPTModel.png +0 -0
  91. msprobe/docs/visualization/ParallelMLP.png +0 -0
  92. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  93. msprobe/docs/visualization/mapping.png +0 -0
  94. msprobe/docs/visualization/mapping1.png +0 -0
  95. msprobe/docs/visualization/module_name.png +0 -0
  96. msprobe/docs/visualization/module_name1.png +0 -0
  97. msprobe/docs/visualization/no_mapping.png +0 -0
  98. msprobe/docs/visualization/no_mapping1.png +0 -0
  99. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  100. msprobe/docs/visualization/top_layer.png +0 -0
  101. msprobe/mindspore/__init__.py +25 -0
  102. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -151
  103. msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
  104. msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
  105. msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
  106. msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
  107. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +64 -1
  108. msprobe/mindspore/api_accuracy_checker/compute_element.py +64 -31
  109. msprobe/mindspore/api_accuracy_checker/data_manager.py +301 -0
  110. msprobe/mindspore/api_accuracy_checker/main.py +28 -3
  111. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +212 -0
  112. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +60 -0
  113. msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
  114. msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
  115. msprobe/mindspore/cell_processor.py +33 -12
  116. msprobe/mindspore/code_mapping/bind.py +264 -0
  117. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  118. msprobe/mindspore/code_mapping/graph.py +49 -0
  119. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  120. msprobe/mindspore/code_mapping/main.py +24 -0
  121. msprobe/mindspore/code_mapping/processor.py +34 -0
  122. msprobe/mindspore/common/const.py +35 -13
  123. msprobe/mindspore/common/log.py +5 -9
  124. msprobe/mindspore/common/utils.py +88 -4
  125. msprobe/mindspore/compare/distributed_compare.py +22 -24
  126. msprobe/mindspore/compare/ms_compare.py +333 -268
  127. msprobe/mindspore/compare/ms_graph_compare.py +95 -52
  128. msprobe/mindspore/debugger/debugger_config.py +7 -1
  129. msprobe/mindspore/debugger/precision_debugger.py +87 -12
  130. msprobe/mindspore/dump/dump_tool_factory.py +3 -1
  131. msprobe/mindspore/dump/hook_cell/api_registry.py +95 -18
  132. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  133. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +45 -30
  134. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +36 -1
  135. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  136. msprobe/mindspore/dump/jit_dump.py +17 -5
  137. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  138. msprobe/mindspore/dump/kernel_graph_dump.py +9 -4
  139. msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
  140. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
  141. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
  142. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +156 -41
  143. msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
  144. msprobe/mindspore/free_benchmark/common/utils.py +19 -4
  145. msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
  146. msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
  147. msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
  148. msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
  149. msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
  150. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
  151. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
  152. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  153. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
  154. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
  155. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
  156. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
  157. msprobe/mindspore/grad_probe/global_context.py +28 -8
  158. msprobe/mindspore/grad_probe/grad_analyzer.py +50 -24
  159. msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
  160. msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
  161. msprobe/mindspore/grad_probe/hook.py +35 -12
  162. msprobe/mindspore/grad_probe/utils.py +18 -5
  163. msprobe/mindspore/mindtorch/__init__.py +18 -0
  164. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  165. msprobe/mindspore/ms_config.py +27 -16
  166. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +9 -4
  167. msprobe/mindspore/runtime.py +15 -0
  168. msprobe/mindspore/service.py +285 -113
  169. msprobe/mindspore/task_handler_factory.py +15 -0
  170. msprobe/msprobe.py +48 -10
  171. msprobe/pytorch/__init__.py +8 -6
  172. msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
  173. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  174. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  175. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +103 -271
  176. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  177. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  178. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  179. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  180. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  181. msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
  182. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +478 -0
  183. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
  184. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  185. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  186. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  187. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  188. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  189. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  190. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  191. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  192. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  193. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +63 -2
  194. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +21 -15
  195. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +54 -22
  196. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +140 -71
  197. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +49 -8
  198. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
  199. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
  200. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  201. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
  202. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
  203. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
  204. msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
  205. msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
  206. msprobe/pytorch/bench_functions/npu_fusion_attention.py +142 -16
  207. msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
  208. msprobe/pytorch/bench_functions/swiglu.py +10 -2
  209. msprobe/pytorch/common/parse_json.py +7 -6
  210. msprobe/pytorch/common/utils.py +101 -7
  211. msprobe/pytorch/compare/distributed_compare.py +17 -30
  212. msprobe/pytorch/compare/pt_compare.py +44 -22
  213. msprobe/pytorch/debugger/debugger_config.py +46 -27
  214. msprobe/pytorch/debugger/precision_debugger.py +42 -12
  215. msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
  216. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  217. msprobe/pytorch/{module_processer.py → dump/module_dump/module_processer.py} +81 -10
  218. msprobe/pytorch/free_benchmark/common/constant.py +15 -0
  219. msprobe/pytorch/free_benchmark/common/counter.py +15 -0
  220. msprobe/pytorch/free_benchmark/common/enums.py +15 -0
  221. msprobe/pytorch/free_benchmark/common/params.py +10 -2
  222. msprobe/pytorch/free_benchmark/common/utils.py +29 -4
  223. msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -5
  224. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
  225. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
  226. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
  227. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
  228. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
  229. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +41 -47
  230. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
  231. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  232. msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
  233. msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
  234. msprobe/pytorch/hook_module/__init__.py +1 -1
  235. msprobe/pytorch/hook_module/hook_module.py +14 -11
  236. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  237. msprobe/pytorch/hook_module/support_wrap_ops.yaml +35 -0
  238. msprobe/pytorch/hook_module/wrap_distributed.py +6 -8
  239. msprobe/pytorch/hook_module/wrap_functional.py +0 -38
  240. msprobe/pytorch/monitor/__init__.py +0 -0
  241. msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
  242. msprobe/pytorch/monitor/anomaly_detect.py +425 -0
  243. msprobe/pytorch/monitor/csv2tb.py +166 -0
  244. msprobe/pytorch/monitor/distributed/__init__.py +0 -0
  245. msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
  246. msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
  247. msprobe/pytorch/monitor/distributed/wrap_distributed.py +283 -0
  248. msprobe/pytorch/monitor/features.py +108 -0
  249. msprobe/pytorch/monitor/module_hook.py +1076 -0
  250. msprobe/pytorch/monitor/module_metric.py +172 -0
  251. msprobe/pytorch/monitor/module_spec_verifier.py +95 -0
  252. msprobe/pytorch/monitor/optimizer_collect.py +333 -0
  253. msprobe/pytorch/monitor/unittest/__init__.py +0 -0
  254. msprobe/pytorch/monitor/unittest/test_monitor.py +160 -0
  255. msprobe/pytorch/monitor/utils.py +321 -0
  256. msprobe/pytorch/monitor/visualizer.py +59 -0
  257. msprobe/pytorch/online_dispatch/__init__.py +2 -3
  258. msprobe/pytorch/online_dispatch/compare.py +29 -38
  259. msprobe/pytorch/online_dispatch/dispatch.py +58 -27
  260. msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
  261. msprobe/pytorch/online_dispatch/single_compare.py +53 -32
  262. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
  263. msprobe/pytorch/online_dispatch/utils.py +49 -21
  264. msprobe/pytorch/parse_tool/lib/compare.py +21 -27
  265. msprobe/pytorch/parse_tool/lib/config.py +6 -8
  266. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  267. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  268. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  269. msprobe/pytorch/parse_tool/lib/parse_tool.py +12 -12
  270. msprobe/pytorch/parse_tool/lib/utils.py +33 -53
  271. msprobe/pytorch/parse_tool/lib/visualization.py +11 -10
  272. msprobe/pytorch/pt_config.py +31 -8
  273. msprobe/pytorch/service.py +188 -108
  274. msprobe/visualization/__init__.py +14 -0
  275. msprobe/visualization/builder/__init__.py +14 -0
  276. msprobe/visualization/builder/graph_builder.py +222 -0
  277. msprobe/visualization/builder/msprobe_adapter.py +227 -0
  278. msprobe/visualization/compare/__init__.py +14 -0
  279. msprobe/visualization/compare/graph_comparator.py +180 -0
  280. msprobe/visualization/compare/mode_adapter.py +197 -0
  281. msprobe/visualization/graph/__init__.py +14 -0
  282. msprobe/visualization/graph/base_node.py +119 -0
  283. msprobe/visualization/graph/distributed_analyzer.py +318 -0
  284. msprobe/visualization/graph/graph.py +209 -0
  285. msprobe/visualization/graph/node_colors.py +95 -0
  286. msprobe/visualization/graph/node_op.py +39 -0
  287. msprobe/visualization/graph_service.py +288 -0
  288. msprobe/visualization/utils.py +217 -0
  289. mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
  290. msprobe/docs/04.acl_config_examples.md +0 -78
  291. msprobe/mindspore/compare/layer_mapping.py +0 -146
  292. msprobe/mindspore/compare/modify_mapping.py +0 -107
  293. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
  294. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
  295. msprobe/pytorch/functional/module_dump.py +0 -84
  296. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/LICENSE +0 -0
  297. {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.2.1.dist-info}/top_level.txt +0 -0
  298. /msprobe/mindspore/{free_benchmark/decorator → code_mapping}/__init__.py +0 -0
  299. /msprobe/pytorch/{functional → dump/module_dump}/__init__.py +0 -0
msprobe/pytorch/monitor/module_metric.py
@@ -0,0 +1,172 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import re
+
+ import torch
+
+ from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+ from msprobe.pytorch.monitor.utils import NAN_TENSOR_ON_DEVICE
+
+
+ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
+     if rank is None:
+         return f"{module_or_param_name}/{tag}"
+     else:
+         return f"{module_or_param_name}/rank{rank}/{tag}"
+
+
+ def squash_param_name(param_name, enable=True):
+     if not enable:
+         return param_name
+     name = ''
+     for pattern in [r'layers?\.(.*)', r'embeddings?\.(.*)', r'final.*', r'output.*', r'norm.*']:
+         match = re.findall(pattern, param_name)
+         if match:
+             name += match[0]
+             break
+     if name == '':
+         name = param_name
+     return name
+
+
+ # Registry of all Metric implementation classes
+ config_metric_registry = {}
+
+
+ def register_config_metric(key, cls=None):
+     """Decorator used to register Metric implementation classes."""
+     if cls is None:
+         # Called without a class: return a decorator function
+         return lambda cls_: register_config_metric(key, cls_)
+     config_metric_registry[key] = cls()
+     return cls
+
+
+ class TensorMetrics:
+     fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean}
+
+     def __init__(self) -> None:
+         self.metrics = {}  # tensor_tag --> []
+         self.cur_idx = {}
+
+     def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
+         """get stats and insert into metrics dictionary"""
+         prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
+         for stat_op in stat_ops:
+             y = TensorMetrics.fun_map[stat_op](tensor)
+             key = f"{prefix}_{stat_op}"
+             if key not in self.metrics:
+                 self.metrics[key] = []
+                 self.cur_idx[key] = 0
+             self.metrics[key].append(y)
+
+     def flush(self, tb_writer):
+         for key, metric_list in self.metrics.items():
+             start = self.cur_idx[key]
+             for v in metric_list[start:]:
+                 tb_writer.add_scalar(key, v.item(), global_step=self.cur_idx[key])
+                 self.cur_idx[key] += 1
+
+
+ class Metric(object):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         raise NotImplementedError
+
+     def get_metric(self, tensor, eps):
+         try:
+             return self.get_metric_value(tensor, eps)
+         except RuntimeError:
+             return torch.tensor(torch.nan).to(tensor.device)
+
+
+ @register_config_metric("min")
+ class MinMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_min(tensor)
+
+
+ @register_config_metric("mean")
+ class MeanMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_mean(tensor)
+
+
+ @register_config_metric("max")
+ class MaxMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_max(tensor)
+
+
+ @register_config_metric("norm")
+ class NormMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_norm(tensor)
+
+
+ @register_config_metric("zeros")
+ class ZerosMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_zeros(tensor, eps)
+
+
+ @register_config_metric("nans")
+ class NaNsMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         return get_nans(tensor)
+
+
+ @register_config_metric("id")
+ class IdentMetric(Metric):
+     @staticmethod
+     def get_metric_value(tensor, eps):
+         if tensor.dim() != 0:
+             return None
+         return tensor
+
+
+ def get_metrics(ops, tag2tensor, eps, out_dict=None):
+     """
+     :param ops: ["op1", "op2"]
+     :param tag2tensor: {
+         '0:fc_0/input': torch.randn([3, 4]),
+         '0:fc_0/output': torch.randn([3, 3])
+     }
+     :param eps: float 1e-8
+     :param out_dict: {
+         '0:fc_0/input': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))},
+         '0:fc_0/output': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+     }
+     :return: out_dict
+     """
+     if out_dict is None:
+         out_dict = {}
+     for tag, tensor in tag2tensor.items():
+         if tag not in out_dict:
+             out_dict[tag] = {}
+         if not torch.is_tensor(tensor):
+             # Non-tensor in/output filled with nan.
+             out_dict[tag].update({metric_name: NAN_TENSOR_ON_DEVICE for metric_name in ops})
+             continue
+         for metric_name in ops:
+             fun_metric = config_metric_registry.get(metric_name)
+             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
+     return out_dict
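
For reference, the registry above maps config keys such as "min" or "norm" to singleton metric objects, and get_metrics applies each requested op to every tagged tensor. A minimal usage sketch, assuming this hunk is msprobe/pytorch/monitor/module_metric.py as the matching +172 entry in the file list suggests; the tag names and eps value are illustrative:

import torch
from msprobe.pytorch.monitor.module_metric import get_metrics

# Tags follow the "<step>:<module>/<io>" convention shown in the docstring above.
tag2tensor = {
    "0:fc_0/input": torch.randn(3, 4),
    "0:fc_0/output": torch.randn(3, 3),
}

# Each tag maps to {op_name: result tensor}; eps only matters for ops like "zeros".
stats = get_metrics(["min", "max", "norm"], tag2tensor, eps=1e-8)
print(stats["0:fc_0/input"]["max"])  # scalar tensor holding the max value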
msprobe/pytorch/monitor/module_spec_verifier.py
@@ -0,0 +1,95 @@
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import re
+ import abc
+ import torch
+
+ from msprobe.pytorch.common.log import logger
+
+ # Registry of all validator implementation classes
+ config_validator_registry = {}
+
+
+ def register_config_validator(cls):
+     """Decorator used to register ConfigValidator implementation classes."""
+     config_validator_registry[cls.__name__] = cls
+     return cls
+
+
+ class ConfigValidator(metaclass=abc.ABCMeta):
+     @abc.abstractmethod
+     def check_pattern_match(self, config_spec: str):
+         pass
+
+     @abc.abstractmethod
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         pass
+
+
+ @register_config_validator
+ class TensorValidator(ConfigValidator):
+     def check_pattern_match(self, config_spec: str):
+         pattern = re.compile(r"tensor")
+         return pattern.match(config_spec)
+
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         if not torch.is_tensor(actual_data):
+             raise ValueError(
+                 f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
+
+
+ @register_config_validator
+ class TupleValidator(ConfigValidator):
+     def check_pattern_match(self, config_spec: str):
+         pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
+         return pattern.match(config_spec)
+
+     def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
+         length, index = pattern_match.groups()
+         if index is None:
+             index = 0
+         length, index = int(length), int(index)
+
+         if not (0 <= index < length):
+             raise ValueError(
+                 f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
+                 f"y must be greater than or equal to 0 and less than x.")
+         if not isinstance(actual_data, tuple):
+             raise ValueError(
+                 f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
+         if len(actual_data) != length:
+             raise ValueError(
+                 f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
+                 f"actual is {len(actual_data)}, please check.")
+         return index
+
+
+ def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
+     focused_col = None
+     if not config_spec or not isinstance(config_spec, str):
+         return focused_col
+     for _, validator_cls in config_validator_registry.items():
+         config_validator = validator_cls()
+         pattern_match = config_validator.check_pattern_match(config_spec)
+         if pattern_match:
+             try:
+                 focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
+             except ValueError as e:
+                 logger.warning(f"config spec validate failed: {str(e)}")
+             return focused_col
+     logger.warning(f"config spec in {module_name} {data_type} not supported, "
+                    rf"expected spec: 'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
+     return focused_col
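
As a quick illustration of the validator dispatch above, here is a sketch assuming this hunk is msprobe/pytorch/monitor/module_spec_verifier.py (the matching +95 entry in the file list); the module and data-type names are made up:

import torch
from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec

# "tuple[3]:1" is matched by TupleValidator: the data must be a 3-tuple, focus on element 1.
data = (torch.randn(2), torch.randn(2), torch.randn(2))
print(validate_config_spec("tuple[3]:1", data, "fc_0", "input"))   # 1

# "tensor" is matched by TensorValidator, which validates the type but returns no focused column.
print(validate_config_spec("tensor", torch.randn(4), "fc_0", "output"))  # None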
msprobe/pytorch/monitor/optimizer_collect.py
@@ -0,0 +1,333 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+
+ import torch
+ import torch.distributed as dist
+
+ from msprobe.pytorch.common.log import logger
+ from msprobe.pytorch.monitor.utils import MVResult, MVGradResult
+
+
+ class OptimizerMon(object):
+     wrapped_optimizer = None
+
+     def __init__(self) -> None:
+         self.fp16_to_fp32_param = {}
+         self.is_stage3 = False
+
+     @classmethod
+     def set_wrapped_optimizer(cls, wrapped_optimizer):
+         cls.wrapped_optimizer = wrapped_optimizer
+
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         pass
+
+     def _fetch_mv_in_adam(self, monitor, torch_opt, params2name):
+         exp_avg_dict = defaultdict(float)
+         exp_avg_sq_dict = defaultdict(float)
+         update_dict = defaultdict()
+         ratio_dict = defaultdict()
+         for param, name in params2name.items():
+             if param in self.fp16_to_fp32_param:
+                 param = self.fp16_to_fp32_param[param]
+
+             if param in torch_opt.state:
+                 state_param = torch_opt.state.get(param, None)
+                 exp_avg = state_param.get("exp_avg", None)
+                 exp_avg_sq = state_param.get("exp_avg_sq", None)
+                 if exp_avg is None or exp_avg_sq is None:
+                     logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
+                     continue
+                 if monitor.mv_distribution:
+                     exp_avg_dict[name] = exp_avg
+                     exp_avg_sq_dict[name] = exp_avg_sq
+                 if monitor.mg_direction:
+                     exp_avg_dict[name] = exp_avg
+                 if monitor.ur_distribution:
+                     if len(torch_opt.param_groups) > 1:
+                         logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                     if 'step' in state_param:
+                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
+                     elif 'step' in torch_opt.param_groups[0]:
+                         step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                     else:
+                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                         continue
+                     exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                     exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                     update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                     ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+     def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat):
+         exp_avg_dict = defaultdict(float)
+         exp_avg_sq_dict = defaultdict(float)
+         update_dict = defaultdict()
+         ratio_dict = defaultdict()
+         param2name = defaultdict()
+         fp32_partitioned_groups_flat_grad = defaultdict()
+         mix_prec_opt = OptimizerMon.wrapped_optimizer
+         partition_id = dist.get_rank()
+
+         def get_flatten_grad(self, optimizer, group_idx):
+             if fp32_partitioned_groups_flat[group_idx].grad is None:
+                 if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
+                     fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
+                         optimizer.averaged_gradients[group_idx],
+                         int(optimizer.partition_size[group_idx])
+                     ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                 else:
+                     fp32_partitioned_groups_flat_grad = optimizer.flatten(
+                         optimizer.averaged_gradients[group_idx]
+                     ).to(fp32_partitioned_groups_flat[group_idx].dtype)
+                 return fp32_partitioned_groups_flat_grad
+             else:
+                 return fp32_partitioned_groups_flat[group_idx].grad
+
+         for group_idx in range(len(fp32_partitioned_groups_flat)):
+             fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
+
+         for name in params2name.values():
+             start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
+             if group_with_rank != partition_id and isinstance(group_with_rank, int):
+                 continue
+             fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
+             fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
+             param2name[fp32_param] = name
+             if not mix_prec_opt.state:
+                 continue
+             state_param = list(mix_prec_opt.state.values())[group_idx]
+             exp_avg = state_param.get("exp_avg", None)
+             exp_avg_sq = state_param.get("exp_avg_sq", None)
+             if exp_avg is None or exp_avg_sq is None:
+                 logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
+                 continue
+             exp_avg = exp_avg[start_idx: end_idx]
+             exp_avg_sq = exp_avg_sq[start_idx: end_idx]
+             if monitor.mv_distribution:
+                 exp_avg_dict[name] = exp_avg
+                 exp_avg_sq_dict[name] = exp_avg_sq
+             if monitor.mg_direction:
+                 exp_avg_dict[name] = exp_avg
+             if monitor.ur_distribution:
+                 if 'step' in state_param:
+                     step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
+                 elif 'step' in torch_opt.param_groups[group_idx]:
+                     step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
+                 else:
+                     logger.warning(f"step of {name} is None, maybe something wrong happened.")
+                     continue
+                 exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
+                 exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
+                 update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                 ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
+                 monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
+                 monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
+         del fp32_partitioned_groups_flat_grad
+         return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
+                             grad=param2name)
+
+
+ class MixPrecisionOptimizerMon(OptimizerMon):
+     """
+     Monitor class for mixed-precision optimizers: tracks and manages the optimizer during mixed-precision training.
+     Mixed-precision training speeds up training and reduces memory use by appropriately lowering the precision of some computations.
+     """
+
+     def map_fp16_tp_fp32_param(self, mix_prec_opt):
+         for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
+             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = fp32_param
+
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             self.map_fp16_tp_fp32_param(mix_prec_opt)
+
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronDistributedOptimizerMon(OptimizerMon):
+     def map_fp16_tp_fp32_param(self, mix_prec_opt):
+         if not (hasattr(mix_prec_opt, "model_float16_groups") and
+                 hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+             raise Exception(
+                 "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
+                 "if not, please check megatron-lm version")
+         for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
+                                                 mix_prec_opt.shard_fp32_from_float16_groups):
+             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             self.map_fp16_tp_fp32_param(mix_prec_opt)
+
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronFP32OptimizerMon(OptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             for opt in mix_prec_opt.chained_optimizers:
+                 self.map_fp16_tp_fp32_param(opt)
+
+         if not isinstance(torch_opt, torch.optim.Optimizer):
+             torch_opt.state = {}
+             for opt in mix_prec_opt.chained_optimizers:
+                 torch_opt.state.update(opt.optimizer.state)
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         mix_prec_opt = self.wrapped_optimizer
+
+         if not self.fp16_to_fp32_param and mix_prec_opt is not None:
+             for opt in mix_prec_opt.chained_optimizers:
+                 self.map_fp16_tp_fp32_param(opt)
+
+         if not isinstance(torch_opt, torch.optim.Optimizer):
+             torch_opt.state = {}
+             for opt in mix_prec_opt.chained_optimizers:
+                 torch_opt.state.update(opt.optimizer.state)
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
+     def get_param_index(self, params2name, name2index):
+         mix_prec_opt = OptimizerMon.wrapped_optimizer
+         fp16_groups = mix_prec_opt.fp16_partitioned_groups
+         name2indices = defaultdict()
+         index_length = defaultdict()
+         index = 0
+         idx = 0
+         for group_idx, fp16_group in enumerate(fp16_groups):
+             for param in fp16_group:
+                 param_length = len(param.flatten())
+                 index_length[idx] = (index, index + param_length, group_idx)
+                 index += param_length
+                 idx += 1
+         for _, name in params2name.items():
+             idx = name2index[name]
+             start_idx, end_idx, group_idx = index_length[idx]
+             name2indices[name] = (start_idx, end_idx, group_idx, None)
+         return name2indices
+
+     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+         self.is_stage3 = True
+         mix_prec_opt = OptimizerMon.wrapped_optimizer
+         fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
+
+
+ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
+
+     @staticmethod
+     def get_group_index(fp32_length, world_size, index):
+         for i in range(len(fp32_length) - 1):
+             if fp32_length[i] <= index < fp32_length[i + 1]:
+                 interval_start = fp32_length[i]
+                 interval_length = fp32_length[i + 1] - fp32_length[i]
+                 sub_interval_length = interval_length // world_size
+                 sub_index = (index - interval_start) // sub_interval_length
+                 sub_interval_start = interval_start + sub_index * sub_interval_length
+                 return sub_interval_start, min(sub_index, world_size - 1)
+         return fp32_length[-1], 0
+
+     def get_param_index(self, params2name, name2index):
+         mix_prec_opt = OptimizerMon.wrapped_optimizer
+         padding = mix_prec_opt.groups_padding
+         world_size = dist.get_world_size()
+         fp32_length = [0]
+         for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+             fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
+
+         bf16_groups = []
+         name2indices = defaultdict()
+         index_length = defaultdict()
+         index = 0
+         idx = 0
+         for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+             bf16_groups.extend(bf16_group)
+             for param in bf16_group:
+                 param_length = len(param.flatten())
+                 group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
+                 index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
+                 index += param_length
+                 idx += 1
+         group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+         for _, name in params2name.items():
+             name_index = name2index[name]
+             start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
+             need_padding = True if group_with_rank == world_size - 1 else False
+             new_start_idx = start_idx - group_index
+             new_end_idx = end_idx - group_index
+             if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
+                     group_length - 1) == 0:
+                 new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
+             name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
+         return name2indices
+
+     def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
+         mix_prec_opt = OptimizerMon.wrapped_optimizer
+         fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+         return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
+
+
+ class DummyOptimizerMon(OptimizerMon):
+     def fetch_mv(self, monitor, torch_opt, params2name):
+         return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class OptimizerMonFactory:
+     _optimizer_mon_map = {
+         "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+         "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
+         "Megatron_ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+         "Megatron_ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+         "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
+         "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
+         "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
+         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
+         "unknown": DummyOptimizerMon
+     }
+
+     @staticmethod
+     def create_optimizer_mon(opt_ty: str):
+         if not opt_ty:
+             return DummyOptimizerMon()
+         optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
+         if not optimizer_mon_class:
+             raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
+         return optimizer_mon_class()
+ return optimizer_mon_class()
File without changes