mindstudio-probe 1.1.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/METADATA +5 -5
- mindstudio_probe-1.1.1.dist-info/RECORD +341 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/WHEEL +1 -1
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/entry_points.txt +0 -1
- msprobe/README.md +39 -3
- msprobe/config.json +1 -3
- msprobe/core/advisor/advisor.py +8 -3
- msprobe/core/common/const.py +113 -13
- msprobe/core/common/exceptions.py +25 -3
- msprobe/core/common/file_utils.py +150 -26
- msprobe/core/common/inplace_op_checker.py +15 -0
- msprobe/core/common/log.py +27 -9
- msprobe/core/common/utils.py +182 -69
- msprobe/core/common_config.py +44 -15
- msprobe/core/compare/acc_compare.py +207 -142
- msprobe/core/compare/check.py +2 -5
- msprobe/core/compare/compare_cli.py +21 -4
- msprobe/core/compare/highlight.py +124 -55
- msprobe/core/compare/layer_mapping/__init__.py +19 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +235 -0
- msprobe/core/compare/layer_mapping/layer_mapping.py +242 -0
- msprobe/core/compare/layer_mapping/postprocess_pass.py +94 -0
- msprobe/core/compare/npy_compare.py +52 -23
- msprobe/core/compare/utils.py +272 -247
- msprobe/core/data_dump/data_collector.py +13 -11
- msprobe/core/data_dump/data_processor/base.py +46 -16
- msprobe/core/data_dump/data_processor/mindspore_processor.py +4 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +156 -59
- msprobe/core/data_dump/scope.py +113 -34
- msprobe/core/grad_probe/constant.py +27 -13
- msprobe/core/grad_probe/grad_compare.py +18 -1
- msprobe/core/grad_probe/utils.py +30 -2
- msprobe/core/overflow_check/abnormal_scene.py +185 -0
- msprobe/core/overflow_check/api_info.py +55 -0
- msprobe/core/overflow_check/checker.py +138 -0
- msprobe/core/overflow_check/filter.py +157 -0
- msprobe/core/overflow_check/ignore_rules.yaml +55 -0
- msprobe/core/overflow_check/level.py +22 -0
- msprobe/core/overflow_check/utils.py +28 -0
- msprobe/docs/01.installation.md +10 -0
- msprobe/docs/02.config_introduction.md +49 -22
- msprobe/docs/03.config_examples.md +2 -9
- msprobe/docs/04.kernel_dump_PyTorch.md +73 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -1
- msprobe/docs/06.data_dump_MindSpore.md +157 -90
- msprobe/docs/07.accuracy_checker_PyTorch.md +12 -12
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +1 -6
- msprobe/docs/09.accuracy_checker_MindSpore.md +44 -8
- msprobe/docs/10.accuracy_compare_PyTorch.md +19 -13
- msprobe/docs/11.accuracy_compare_MindSpore.md +104 -13
- msprobe/docs/12.overflow_check_PyTorch.md +1 -1
- msprobe/docs/13.overflow_check_MindSpore.md +6 -6
- msprobe/docs/15.free_benchmarking_PyTorch.md +4 -5
- msprobe/docs/16.free_benchmarking_MindSpore.md +56 -37
- msprobe/docs/17.grad_probe.md +5 -6
- msprobe/docs/19.monitor.md +468 -0
- msprobe/docs/20.monitor_performance_baseline.md +52 -0
- msprobe/docs/21.visualization_PyTorch.md +386 -0
- msprobe/docs/22.visualization_MindSpore.md +384 -0
- msprobe/docs/23.tool_function_introduction.md +28 -0
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/data_dump_Mindspore/dynamic_graph_quick_start_example.md +211 -0
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/monitor/cpu_info.png +0 -0
- msprobe/mindspore/__init__.py +15 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +113 -145
- msprobe/mindspore/api_accuracy_checker/api_info.py +21 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +43 -18
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +21 -7
- msprobe/mindspore/api_accuracy_checker/checker_support_api.yaml +77 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +63 -1
- msprobe/mindspore/api_accuracy_checker/compute_element.py +59 -24
- msprobe/mindspore/api_accuracy_checker/data_manager.py +264 -0
- msprobe/mindspore/api_accuracy_checker/main.py +27 -3
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +206 -0
- msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +58 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +22 -5
- msprobe/mindspore/api_accuracy_checker/utils.py +34 -17
- msprobe/mindspore/cell_processor.py +33 -12
- msprobe/mindspore/common/const.py +33 -13
- msprobe/mindspore/common/log.py +5 -9
- msprobe/mindspore/common/utils.py +43 -4
- msprobe/mindspore/compare/distributed_compare.py +22 -22
- msprobe/mindspore/compare/ms_compare.py +271 -248
- msprobe/mindspore/compare/ms_graph_compare.py +81 -47
- msprobe/mindspore/debugger/debugger_config.py +4 -1
- msprobe/mindspore/debugger/precision_debugger.py +7 -1
- msprobe/mindspore/dump/dump_tool_factory.py +3 -1
- msprobe/mindspore/dump/hook_cell/api_registry.py +12 -2
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +13 -16
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +25 -0
- msprobe/mindspore/dump/jit_dump.py +17 -5
- msprobe/mindspore/dump/kernel_graph_dump.py +2 -4
- msprobe/mindspore/dump/kernel_kbyk_dump.py +2 -4
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +140 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +53 -0
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +145 -39
- msprobe/mindspore/free_benchmark/common/handler_params.py +1 -2
- msprobe/mindspore/free_benchmark/common/utils.py +19 -4
- msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml +0 -204
- msprobe/mindspore/free_benchmark/handler/base_handler.py +3 -3
- msprobe/mindspore/free_benchmark/handler/check_handler.py +4 -5
- msprobe/mindspore/free_benchmark/handler/fix_handler.py +4 -4
- msprobe/mindspore/free_benchmark/handler/handler_factory.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/add_noise.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -6
- msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +4 -4
- msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +2 -2
- msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +13 -6
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +2 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +2 -2
- msprobe/mindspore/grad_probe/global_context.py +28 -8
- msprobe/mindspore/grad_probe/grad_analyzer.py +27 -13
- msprobe/mindspore/grad_probe/grad_monitor.py +16 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +33 -5
- msprobe/mindspore/grad_probe/hook.py +24 -10
- msprobe/mindspore/grad_probe/utils.py +18 -5
- msprobe/mindspore/ms_config.py +22 -15
- msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +2 -4
- msprobe/mindspore/runtime.py +15 -0
- msprobe/mindspore/service.py +36 -30
- msprobe/mindspore/task_handler_factory.py +15 -0
- msprobe/msprobe.py +24 -7
- msprobe/pytorch/__init__.py +3 -2
- msprobe/pytorch/api_accuracy_checker/common/config.py +62 -0
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -4
- msprobe/pytorch/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +454 -0
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +365 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +6 -1
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +19 -14
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +13 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +77 -53
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +15 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +9 -24
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +4 -12
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +9 -4
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +3 -11
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -2
- msprobe/pytorch/bench_functions/confusion_transpose.py +5 -1
- msprobe/pytorch/bench_functions/matmul_backward.py +12 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +100 -6
- msprobe/pytorch/bench_functions/rotary_mul.py +4 -0
- msprobe/pytorch/bench_functions/swiglu.py +10 -2
- msprobe/pytorch/common/parse_json.py +6 -6
- msprobe/pytorch/common/utils.py +56 -5
- msprobe/pytorch/compare/distributed_compare.py +8 -9
- msprobe/pytorch/compare/pt_compare.py +8 -6
- msprobe/pytorch/debugger/debugger_config.py +19 -15
- msprobe/pytorch/dump/kernel_dump/kernel_config.py +33 -0
- msprobe/pytorch/free_benchmark/common/constant.py +15 -0
- msprobe/pytorch/free_benchmark/common/counter.py +15 -0
- msprobe/pytorch/free_benchmark/common/enums.py +15 -0
- msprobe/pytorch/free_benchmark/common/params.py +8 -1
- msprobe/pytorch/free_benchmark/common/utils.py +26 -4
- msprobe/pytorch/free_benchmark/compare/grad_saver.py +20 -3
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +6 -4
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +2 -0
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +4 -0
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +10 -0
- msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +6 -5
- msprobe/pytorch/grad_probe/grad_monitor.py +23 -6
- msprobe/pytorch/grad_probe/grad_stat_csv.py +40 -10
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +1 -0
- msprobe/pytorch/hook_module/wrap_functional.py +14 -12
- msprobe/pytorch/module_processer.py +2 -5
- msprobe/pytorch/monitor/anomaly_analyse.py +201 -0
- msprobe/pytorch/monitor/anomaly_detect.py +340 -0
- msprobe/pytorch/monitor/distributed/__init__.py +0 -0
- msprobe/pytorch/monitor/distributed/distributed_ops.yaml +19 -0
- msprobe/pytorch/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +272 -0
- msprobe/pytorch/monitor/features.py +108 -0
- msprobe/pytorch/monitor/module_hook.py +870 -0
- msprobe/pytorch/monitor/module_metric.py +193 -0
- msprobe/pytorch/monitor/module_spec_verifier.py +93 -0
- msprobe/pytorch/monitor/optimizer_collect.py +295 -0
- msprobe/pytorch/monitor/unittest/__init__.py +0 -0
- msprobe/pytorch/monitor/unittest/test_monitor.py +145 -0
- msprobe/pytorch/monitor/utils.py +250 -0
- msprobe/pytorch/monitor/visualizer.py +59 -0
- msprobe/pytorch/online_dispatch/__init__.py +2 -3
- msprobe/pytorch/online_dispatch/compare.py +29 -38
- msprobe/pytorch/online_dispatch/dispatch.py +50 -25
- msprobe/pytorch/online_dispatch/dump_compare.py +21 -9
- msprobe/pytorch/online_dispatch/single_compare.py +53 -32
- msprobe/pytorch/online_dispatch/torch_ops_config.yaml +1 -1
- msprobe/pytorch/online_dispatch/utils.py +49 -21
- msprobe/pytorch/parse_tool/lib/compare.py +12 -18
- msprobe/pytorch/parse_tool/lib/config.py +1 -1
- msprobe/pytorch/parse_tool/lib/parse_tool.py +1 -2
- msprobe/pytorch/parse_tool/lib/utils.py +16 -35
- msprobe/pytorch/parse_tool/lib/visualization.py +2 -0
- msprobe/pytorch/pt_config.py +31 -8
- msprobe/pytorch/service.py +15 -5
- msprobe/visualization/__init__.py +14 -0
- msprobe/visualization/builder/__init__.py +14 -0
- msprobe/visualization/builder/graph_builder.py +165 -0
- msprobe/visualization/builder/msprobe_adapter.py +205 -0
- msprobe/visualization/compare/__init__.py +14 -0
- msprobe/visualization/compare/graph_comparator.py +130 -0
- msprobe/visualization/compare/mode_adapter.py +211 -0
- msprobe/visualization/graph/__init__.py +14 -0
- msprobe/visualization/graph/base_node.py +124 -0
- msprobe/visualization/graph/graph.py +200 -0
- msprobe/visualization/graph/node_colors.py +95 -0
- msprobe/visualization/graph/node_op.py +39 -0
- msprobe/visualization/graph_service.py +214 -0
- msprobe/visualization/utils.py +232 -0
- mindstudio_probe-1.1.0.dist-info/RECORD +0 -287
- msprobe/docs/04.acl_config_examples.md +0 -78
- msprobe/mindspore/compare/layer_mapping.py +0 -146
- msprobe/mindspore/compare/modify_mapping.py +0 -107
- msprobe/mindspore/free_benchmark/decorator/dec_forward.py +0 -57
- msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +0 -122
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.1.0.dist-info → mindstudio_probe-1.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore/free_benchmark/decorator → pytorch/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/module_metric.py (new file)
@@ -0,0 +1,193 @@

```python
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import math
import re
import statistics

import torch

from msprobe.core.common.const import MonitorConst
from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean
from msprobe.core.common.log import logger


def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
    if rank is None:
        return f"{module_or_param_name}/{tag}"
    else:
        return f"{module_or_param_name}/rank{rank}/{tag}"


def squash_param_name(param_name):
    name = ''
    for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
        match = re.findall(pattern, param_name)
        if match:
            name += match[0]
            break
    if name == '':
        name = param_name
    return name


# Registry of all Metric implementation classes
config_metric_registry = {}


def register_config_metric(key, cls=None):
    """Decorator used to register Metric implementation classes."""
    if cls is None:
        # Called with only the key: return a decorator
        return lambda cls_: register_config_metric(key, cls_)
    config_metric_registry[key] = cls()
    return cls


class TensorMetrics:
    fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean}

    def __init__(self) -> None:
        self.metrics = {}  # tensor_tag --> []
        self.cur_idx = {}

    def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8):
        """get stats and insert into metrics dictionary"""
        prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
        for stat_op in stat_ops:
            y = TensorMetrics.fun_map[stat_op](tensor)
            key = f"{prefix}_{stat_op}"
            if key not in self.metrics:
                self.metrics[key] = []
                self.cur_idx[key] = 0
            self.metrics[key].append(y)

    def flush(self, tb_writer):
        for key, metric_list in self.metrics.items():
            start = self.cur_idx[key]
            for v in metric_list[start:]:
                tb_writer.add_scalar(key, v.item(), global_step=self.cur_idx[key])
                self.cur_idx[key] += 1


class Metric(object):
    @staticmethod
    def get_metric_value(tensor, eps):
        NotImplementedError

    def get_metric(self, tensor, eps):
        try:
            return self.get_metric_value(tensor, eps)
        except RuntimeError as e:
            return torch.tensor(torch.nan).to(tensor.device)


@register_config_metric("min")
class MinMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_min(tensor)


@register_config_metric("mean")
class MeanMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_mean(tensor)


@register_config_metric("max")
class MaxMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_max(tensor)


@register_config_metric("norm")
class NormMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_norm(tensor)


@register_config_metric("zeros")
class ZerosMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_zeros(tensor, eps)


@register_config_metric("nans")
class NaNsMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return get_nans(tensor)


@register_config_metric("id")
class IdentMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        if tensor.dim() != 0:
            return None
        return tensor


def get_metrics(ops, tag2tensor, eps, out_dict=None):
    if out_dict is None:
        out_dict = {}
    for tag, tensor in tag2tensor.items():
        if tag not in out_dict:
            out_dict[tag] = {}
        for metric_name in ops:
            fun_metric = config_metric_registry.get(metric_name)
            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
    return out_dict


def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
    if not metric_value:
        return
    tensors = []
    tags = list(itertools.product(metric_value.keys(), ops))
    for op2tensor in metric_value.values():
        tensors.extend(op2tensor.values())
    with torch.no_grad():
        metric_list = torch.stack(tensors).cpu()
    for tag, metric in zip(tags, metric_list):
        summary_writer.add_scalar(tag, metric, step)


def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
    write_metrics_base(ops, summary_writer, metric_value, step, prefix='')

    if not summary_writer.header:
        # Forward norms use input.<op> and output.<op>; backward norms use input_grad.<op> and output_grad.<op>
        if prefix in {"actv", "actv_grad"}:
            if prefix == "actv":
                input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
            else:
                input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
            ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
            summary_writer.header = ["module_name", "step", *ops_]
        else:
            summary_writer.header = ["param_name", "step", *ops]

        for key in metric_value.keys():
            if MonitorConst.VPP_SEP in key:
                summary_writer.header.insert(0, 'vpp_stage')
                break
    summary_writer.write_csv(prefix, step)
    summary_writer.header = []
```
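To make the registry's behaviour concrete, here is a minimal usage sketch, not taken from the package itself: it assumes mindstudio-probe is installed so the module path above is importable, and the module/tensor names are invented. `get_metrics` looks up each requested op in `config_metric_registry` and returns a nested tag-to-op-to-value mapping.

```python
import torch
from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name

# Hypothetical activation tensors keyed by summary-writer tags (names are illustrative only).
tag2tensor = {
    get_summary_writer_tag_name("decoder.layers.0", "input", rank=0): torch.randn(4, 8),
    get_summary_writer_tag_name("decoder.layers.0", "output", rank=0): torch.randn(4, 8),
}

# "min"/"max"/"norm"/"nans" are keys registered via @register_config_metric above.
stats = get_metrics(["min", "max", "norm", "nans"], tag2tensor, eps=1e-8)
for tag, op2value in stats.items():
    print(tag, {op: float(v) for op, v in op2value.items()})
```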
msprobe/pytorch/monitor/module_spec_verifier.py (new file)
@@ -0,0 +1,93 @@

```python
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import abc
import torch

from msprobe.core.common.log import logger

# Registry of all validator implementation classes
config_validator_registry = {}


def register_config_validator(cls):
    """Decorator used to register ConfigValidator implementation classes."""
    config_validator_registry[cls.__name__] = cls
    return cls


class ConfigValidator(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        pass

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        pass


@register_config_validator
class TensorValidator(ConfigValidator):
    def check_pattern_match(self, config_spec: str):
        pattern = re.compile(r"tensor")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        if not torch.is_tensor(actual_data):
            raise ValueError(
                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")


@register_config_validator
class TupleValidator(ConfigValidator):
    def check_pattern_match(self, config_spec: str):
        pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        length, index = pattern_match.groups()
        if index is None:
            index = 0
        length, index = int(length), int(index)

        if not (0 <= index < length):
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'."
                f"y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != length:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
                f"actual is {len(actual_data)} please check.")
        return index


def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    focused_col = None
    for _, validator_cls in config_validator_registry.items():
        config_validator = validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if pattern_match:
            try:
                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
            except ValueError as e:
                logger.warning(f"config spec validate failed: {str(e)}")
            return focused_col
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   f"expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.")
    return focused_col
```
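A short illustration of how `validate_config_spec` is meant to be called, again assuming the module is importable from an installed package; the module name and data below are made up. A "tuple[x]:y" spec yields the focused index y, while a plain "tensor" spec leaves the focused column as None.

```python
import torch
from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec

# A module whose forward returns a 3-tuple; the config says to focus on element 1.
outputs = (torch.zeros(2), torch.ones(2), torch.arange(2.0))
col = validate_config_spec("tuple[3]:1", outputs, module_name="decoder.block0", data_type="output")
print(col)  # 1

# Plain tensor spec: nothing to index into, so the focused column stays None.
col = validate_config_spec("tensor", torch.randn(2, 2), module_name="decoder.block0", data_type="input")
print(col)  # None
```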
msprobe/pytorch/monitor/optimizer_collect.py (new file)
@@ -0,0 +1,295 @@

```python
# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections import defaultdict

import torch
import torch.distributed as dist

from msprobe.core.common.log import logger
from msprobe.pytorch.monitor.utils import MVResult, MVGradResult


class OptimizerMon(object):
    wrapped_optimizer = None

    def __init__(self) -> None:
        self.fp16_to_fp32_param = {}
        self.is_stage3 = False

    @classmethod
    def set_wrapped_optimizer(cls, wrapped_optimizer):
        cls.wrapped_optimizer = wrapped_optimizer

    def fetch_mv(self, monitor, torch_opt, params2name):
        pass

    def _fetch_mv_in_adam(self, monitor, torch_opt, params2name):
        exp_avg_dict = defaultdict(float)
        exp_avg_sq_dict = defaultdict(float)
        update_dict = defaultdict()
        ratio_dict = defaultdict()
        for param, name in params2name.items():
            if param in self.fp16_to_fp32_param:
                param = self.fp16_to_fp32_param[param]

            if param in torch_opt.state:
                state_param = torch_opt.state.get(param, None)
                exp_avg = state_param.get("exp_avg", None)
                exp_avg_sq = state_param.get("exp_avg_sq", None)
                if exp_avg is None or exp_avg_sq is None:
                    logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
                    continue
                if monitor.mv_distribution:
                    exp_avg_dict[name] = exp_avg
                    exp_avg_sq_dict[name] = exp_avg_sq
                if monitor.mg_direction:
                    exp_avg_dict[name] = exp_avg
                if monitor.ur_distribution:
                    if len(torch_opt.param_groups) > 1:
                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
                    if 'step' in state_param:
                        step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
                    elif 'step' in torch_opt.param_groups[0]:
                        step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
                    else:
                        logger.warning(f"step of {name} is None, maybe something wrong happened.")
                        continue
                    exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
                    exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
                    ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                    monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                    monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)

    def _fetch_mv_grad_in_adam(self, monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat):
        exp_avg_dict = defaultdict(float)
        exp_avg_sq_dict = defaultdict(float)
        update_dict = defaultdict()
        ratio_dict = defaultdict()
        param2name = defaultdict()
        fp32_partitioned_groups_flat_grad = defaultdict()
        mix_prec_opt = OptimizerMon.wrapped_optimizer
        partition_id = dist.get_rank()

        def get_flatten_grad(self, optimizer, group_idx):
            if fp32_partitioned_groups_flat[group_idx].grad is None:
                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
                        optimizer.averaged_gradients[group_idx],
                        int(optimizer.partition_size[group_idx])
                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
                else:
                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
                        optimizer.averaged_gradients[group_idx]
                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
                return fp32_partitioned_groups_flat_grad
            else:
                return fp32_partitioned_groups_flat[group_idx].grad

        for group_idx in range(len(fp32_partitioned_groups_flat)):
            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)

        for name in params2name.values():
            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
            if group_with_rank != partition_id and isinstance(group_with_rank, int):
                continue
            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
            param2name[fp32_param] = name
            if not mix_prec_opt.state:
                continue
            state_param = list(mix_prec_opt.state.values())[group_idx]
            exp_avg = state_param.get("exp_avg", None)
            exp_avg_sq = state_param.get("exp_avg_sq", None)
            if exp_avg is None or exp_avg_sq is None:
                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
                continue
            exp_avg = exp_avg[start_idx: end_idx]
            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
            if monitor.mv_distribution:
                exp_avg_dict[name] = exp_avg
                exp_avg_sq_dict[name] = exp_avg_sq
            if monitor.mg_direction:
                exp_avg_dict[name] = exp_avg
            if monitor.ur_distribution:
                if 'step' in state_param:
                    step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex (used by megatron)
                elif 'step' in torch_opt.param_groups[group_idx]:
                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
                else:
                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
                    continue
                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
        del fp32_partitioned_groups_flat_grad
        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
                            grad=param2name)


class MixPrecisionOptimizerMon(OptimizerMon):
    """
    Mixed-precision optimizer monitor: monitors and manages the optimizer during mixed-precision training.
    Mixed-precision training speeds up training and reduces memory consumption by appropriately lowering
    the precision of certain computations.
    """

    def fetch_mv(self, monitor, torch_opt, params2name):
        mix_prec_opt = self.wrapped_optimizer

        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
            for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
                for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                    self.fp16_to_fp32_param[fp16_param] = fp32_param
        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


class MegatronDistributedOptimizerMon(OptimizerMon):
    def fetch_mv(self, monitor, torch_opt, params2name):
        mix_prec_opt = self.wrapped_optimizer
        if not (hasattr(mix_prec_opt, "model_float16_groups") and
                hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
            raise Exception(
                "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
                "if not, please check megatron-lm version")
        if not self.fp16_to_fp32_param and mix_prec_opt is not None:
            for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
                                                    mix_prec_opt.shard_fp32_from_float16_groups):
                for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                    self.fp16_to_fp32_param[fp16_param] = shard_fp32_param

        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


class MegatronFP32OptimizerMon(OptimizerMon):
    def fetch_mv(self, monitor, torch_opt, params2name):
        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
    def fetch_mv(self, monitor, torch_opt, params2name):
        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
    def get_param_index(self, params2name, name2index):
        mix_prec_opt = OptimizerMon.wrapped_optimizer
        fp16_groups = mix_prec_opt.fp16_partitioned_groups
        name2indices = defaultdict()
        index_length = defaultdict()
        index = 0
        idx = 0
        for group_idx, fp16_group in enumerate(fp16_groups):
            for param in fp16_group:
                param_length = len(param.flatten())
                index_length[idx] = (index, index + param_length, group_idx)
                index += param_length
                idx += 1
        for _, name in params2name.items():
            idx = name2index[name]
            start_idx, end_idx, group_idx = index_length[idx]
            name2indices[name] = (start_idx, end_idx, group_idx, None)
        return name2indices

    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
        self.is_stage3 = True
        mix_prec_opt = OptimizerMon.wrapped_optimizer
        fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)


class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):

    @staticmethod
    def get_group_index(fp32_length, world_size, index):
        for i in range(len(fp32_length) - 1):
            if fp32_length[i] <= index < fp32_length[i + 1]:
                interval_start = fp32_length[i]
                interval_length = fp32_length[i + 1] - fp32_length[i]
                sub_interval_length = interval_length // world_size
                sub_index = (index - interval_start) // sub_interval_length
                sub_interval_start = interval_start + sub_index * sub_interval_length
                return sub_interval_start, min(sub_index, world_size - 1)
        return fp32_length[-1], 0

    def get_param_index(self, params2name, name2index):
        mix_prec_opt = OptimizerMon.wrapped_optimizer
        padding = mix_prec_opt.groups_padding
        world_size = dist.get_world_size()
        fp32_length = [0]
        for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])

        bf16_groups = []
        name2indices = defaultdict()
        index_length = defaultdict()
        index = 0
        idx = 0
        for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
            bf16_groups.extend(bf16_group)
            for param in bf16_group:
                param_length = len(param.flatten())
                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
                index += param_length
                idx += 1
        group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
        for _, name in params2name.items():
            name_index = name2index[name]
            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
            need_padding = True if group_with_rank == world_size - 1 else False
            new_start_idx = start_idx - group_index
            new_end_idx = end_idx - group_index
            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
                    group_length - 1) == 0:
                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
        return name2indices

    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
        mix_prec_opt = OptimizerMon.wrapped_optimizer
        fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)


class DummyOptimizerMon(OptimizerMon):
    def fetch_mv(self, monitor, torch_opt, params2name):
        return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None)


class OptimizerMonFactory:
    _optimizer_mon_map = {
        "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
        "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
        "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
        "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
        "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
        "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
        "unknown": DummyOptimizerMon
    }

    @staticmethod
    def create_optimizer_mon(opt_ty: str):
        if not opt_ty:
            return DummyOptimizerMon()
        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
        if not optimizer_mon_class:
            raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
        return optimizer_mon_class()
```