mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
- msprobe/README.md +32 -1
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +120 -21
- msprobe/core/common/exceptions.py +2 -2
- msprobe/core/common/file_utils.py +279 -50
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +136 -45
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +646 -428
- msprobe/core/compare/check.py +36 -103
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +215 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
- msprobe/core/compare/merge_result/merge_result.py +4 -4
- msprobe/core/compare/multiprocessing_compute.py +223 -110
- msprobe/core/compare/npy_compare.py +2 -4
- msprobe/core/compare/utils.py +214 -244
- msprobe/core/config_check/__init__.py +17 -0
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +67 -4
- msprobe/core/data_dump/data_collector.py +170 -89
- msprobe/core/data_dump/data_processor/base.py +72 -51
- msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
- msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
- msprobe/core/data_dump/json_writer.py +143 -27
- msprobe/core/debugger/precision_debugger.py +144 -0
- msprobe/core/grad_probe/constant.py +1 -1
- msprobe/core/grad_probe/grad_compare.py +1 -1
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/service.py +357 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +146 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +79 -22
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +118 -49
- msprobe/docs/06.data_dump_MindSpore.md +167 -20
- msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
- msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
- msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
- msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
- msprobe/docs/12.overflow_check_PyTorch.md +2 -2
- msprobe/docs/13.overflow_check_MindSpore.md +2 -2
- msprobe/docs/14.data_parse_PyTorch.md +3 -3
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +2 -2
- msprobe/docs/19.monitor.md +90 -44
- msprobe/docs/21.visualization_PyTorch.md +68 -15
- msprobe/docs/22.visualization_MindSpore.md +71 -18
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +1 -1
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/29.data_dump_MSAdapter.md +2 -2
- msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +181 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/mindspore/__init__.py +1 -2
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
- msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +17 -7
- msprobe/mindspore/common/utils.py +128 -11
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +17 -405
- msprobe/mindspore/compare/ms_graph_compare.py +14 -5
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +53 -3
- msprobe/mindspore/debugger/precision_debugger.py +72 -91
- msprobe/mindspore/dump/cell_dump_process.py +877 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
- msprobe/mindspore/dump/dump_tool_factory.py +13 -5
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
- msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
- msprobe/mindspore/dump/jit_dump.py +21 -18
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/grad_probe/global_context.py +7 -2
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/mindspore_service.py +114 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/features.py +20 -7
- msprobe/mindspore/monitor/module_hook.py +281 -209
- msprobe/mindspore/monitor/optimizer_collect.py +334 -0
- msprobe/mindspore/monitor/utils.py +25 -5
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +26 -14
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +66 -118
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
- msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +29 -5
- msprobe/pytorch/hook_module/hook_module.py +9 -18
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +6 -2
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/module_hook.py +227 -158
- msprobe/pytorch/monitor/module_metric.py +14 -0
- msprobe/pytorch/monitor/optimizer_collect.py +242 -270
- msprobe/pytorch/monitor/utils.py +16 -3
- msprobe/pytorch/online_dispatch/dispatch.py +4 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
- msprobe/pytorch/parse_tool/lib/utils.py +3 -3
- msprobe/pytorch/pt_config.py +8 -7
- msprobe/pytorch/pytorch_service.py +73 -0
- msprobe/visualization/builder/graph_builder.py +33 -13
- msprobe/visualization/builder/msprobe_adapter.py +24 -11
- msprobe/visualization/compare/graph_comparator.py +53 -45
- msprobe/visualization/compare/mode_adapter.py +31 -1
- msprobe/visualization/graph/base_node.py +3 -3
- msprobe/visualization/graph/graph.py +2 -2
- msprobe/visualization/graph_service.py +250 -103
- msprobe/visualization/utils.py +27 -11
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -549
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -473
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -144,6 +144,20 @@ class IdentMetric(Metric):
         return tensor
 
 
+@register_config_metric("shape")
+class ShapeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.shape
+
+
+@register_config_metric("dtype")
+class DtypeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.dtype
+
+
 def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
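The hunk above (evidently the +14/-0 change to msprobe/pytorch/monitor/module_metric.py) registers two new metrics, `shape` and `dtype`, through the `@register_config_metric` decorator. The sketch below illustrates the decorator-registry pattern this relies on; the `config_metric_registry` dict and the decorator body are illustrative assumptions, not code taken from the package.

```python
# Sketch of a decorator-based metric registry (assumed names, for illustration only).
config_metric_registry = {}  # assumed: maps a config key such as "shape" to a Metric subclass


def register_config_metric(name):
    """Register a Metric subclass under the given config key."""
    def decorator(cls):
        config_metric_registry[name] = cls
        return cls
    return decorator


class Metric:
    @staticmethod
    def get_metric_value(tensor, eps):
        raise NotImplementedError


@register_config_metric("shape")
class ShapeMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return tensor.shape


# A caller can then look metrics up by name, e.g.:
#     config_metric_registry["shape"].get_metric_value(tensor, eps=1e-8)
```

Registering by name is what lets `get_metrics(ops, tag2tensor, eps, ...)` iterate over whichever ops the monitor config lists without hard-coding the new metric classes.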
@@ -12,129 +12,123 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from collections import defaultdict
+from abc import abstractmethod
 
 import torch
-import torch.distributed as dist
 
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.core.common.const import MonitorConst
 
 
 class OptimizerMon(object):
-    def __init__(self) -> None:
+    def __init__(self, torch_opt) -> None:
         self.fp16_to_fp32_param = {}
-        self.…
+        self.torch_opt = torch_opt
+        self.state = {}
 
-    def …
-
+    def narrow_from_flatten(self, param, flatten_state):
+        return flatten_state
+
+    def get_state(self, torch_opt):
+        if hasattr(torch_opt, 'chained_optimizers'):
+            for opt in torch_opt.chained_optimizers:
+                self._get_single_state(opt)
+        else:
+            self._get_single_state(torch_opt)
 
-    def …
-… (4 removed lines; content not captured)
+    def fetch_grad(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+
+        grad_dict = {}
+        first_param = True
         for param, name in params2name.items():
-            if …
-… (6 removed lines; content not captured)
-            if …
-
-
+            if monitor.duplicate_param.get(name, False):
+                continue
+            if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                continue
+            grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+            if param.numel() != element_in_cur_partition:
+                if first_param:
+                    grad = grad.flatten()[-element_in_cur_partition:]
+                else:  # supposed to be the last one
+                    grad = grad.flatten()[:element_in_cur_partition]
+                first_param = False
+
+            if grad is None:
+                if not monitor.fsdp_wrapped_module:
+                    logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        return grad_dict
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        pass
+
+    def fetch_mv(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+        if not self.state:
+            self.get_state(self.torch_opt)
+
+        exp_avg_dict = {}
+        exp_avg_sq_dict = {}
+        update_dict = {}
+        ratio_dict = {}
+
+        if not self.state:
+            logger.warning('optimizer state can not accessed')
+            return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+        for lp_param, name in params2name.items():
+            if lp_param in self.fp16_to_fp32_param:
+                hp_param = self.fp16_to_fp32_param[lp_param]
+            else:
+                hp_param = lp_param
+
+            if hp_param in self.state:
+                state_param = self.state.get(hp_param, {})
+                exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
                 if monitor.mv_distribution:
                     exp_avg_dict[name] = exp_avg
                     exp_avg_sq_dict[name] = exp_avg_sq
                 if monitor.mg_direction:
                     exp_avg_dict[name] = exp_avg
                 if monitor.ur_distribution:
-                    if len(torch_opt.param_groups) > 1:
-                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                    if len(self.torch_opt.param_groups) > 1:
+                        logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.")
                     if 'step' in state_param:
                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                    elif 'step' in torch_opt.param_groups[0]:
-                        step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                    elif 'step' in self.torch_opt.param_groups[0]:
+                        step = self.torch_opt.param_groups[0]['step']  # AdamW from mindspeed
                     else:
                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
                         continue
-… (3 removed lines; content not captured)
+                    if exp_avg is None or exp_avg_sq is None:
+                        logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
+                        continue
+                    exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
                     ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
-    def …
-… (8 removed lines; content not captured)
-        def get_flatten_grad(self, optimizer, group_idx):
-            if fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
-                        optimizer.averaged_gradients[group_idx],
-                        int(optimizer.partition_size[group_idx])
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                else:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
-                        optimizer.averaged_gradients[group_idx]
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                return fp32_partitioned_groups_flat_grad
-            else:
-                return fp32_partitioned_groups_flat[group_idx].grad
-
-        for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
-
-        for name in params2name.values():
-            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
-            if group_with_rank != partition_id and isinstance(group_with_rank, int):
-                continue
-            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
-            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
-            param2name[fp32_param] = name
-            if not torch_opt.state:
-                continue
-            state_param = list(torch_opt.state.values())[group_idx]
-            exp_avg = state_param.get("exp_avg", None)
-            exp_avg_sq = state_param.get("exp_avg_sq", None)
-            if exp_avg is None or exp_avg_sq is None:
-                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                continue
-            exp_avg = exp_avg[start_idx: end_idx]
-            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
-            if monitor.mv_distribution:
-                exp_avg_dict[name] = exp_avg
-                exp_avg_sq_dict[name] = exp_avg_sq
-            if monitor.mg_direction:
-                exp_avg_dict[name] = exp_avg
-            if monitor.ur_distribution:
-                if 'step' in state_param:
-                    step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                elif 'step' in torch_opt.param_groups[group_idx]:
-                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
-                else:
-                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
-                    continue
-                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        del fp32_partitioned_groups_flat_grad
-        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
-                            grad=param2name)
+
+    def _get_single_state(self, torch_opt):
+        state = {}
+        if hasattr(torch_opt, 'param_to_cpu_states_map'):
+            state = torch_opt.param_to_cpu_states_map
+        elif hasattr(torch_opt, 'state'):
+            state = torch_opt.state
+        elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'):
+            state = torch_opt.optimizer.state
+        self.state.update(state)
 
 
 class MixPrecisionOptimizerMon(OptimizerMon):
@@ -142,21 +136,14 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     混合精度优化器监控类。在混合精度训练中监控和管理优化器。
     混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
     """
-
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def …
+    def map_fp16_to_fp32_param(self, torch_opt):
         if not (hasattr(torch_opt, "model_float16_groups") and
                 hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
@@ -167,192 +154,176 @@ class MegatronDistributedOptimizerMon(OptimizerMon):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
-
-class MegatronFP32OptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
 
 class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
-    def …
-
-
-            self.map_fp16_tp_fp32_param(opt)
-
-        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
 
 class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
-    def …
-
-
-            self.map_fp16_tp_fp32_param(opt)
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
-        if not isinstance(torch_opt, torch.optim.Optimizer) and not hasattr(torch_opt, 'state'):
-            torch_opt.state = {}
-            for opt in torch_opt.chained_optimizers:
-                torch_opt.state.update(opt.optimizer.state)
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
-… (6 removed lines; content not captured)
+class DeepSpeedZeroOptimizerMon(OptimizerMon):
+    """
+    Base monitor class for DeepSpeed ZeRO optimizer.
+    ZeRO stage 0 no partition
+    ZeRO stage 1 partitions optimizer states across data parallel processes.
+    ZeRO stage 2 additionally partitions gradients.
+    ZeRO stage 3 additionally partitions parameters.
+
+    This class provides monitoring capabilities for ZeRO optimizers by:
+    - Handling gradient collection for different ZeRO stages
+    - Managing optimizer state access for monitoring
+    """
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = ''
+        self.bit16_groups = []
+        self.fp32_flat_groups = []
+        self.param2group = ()
+        self.param2index = []
+        self.group_offset = {}
+
+    @abstractmethod
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        raise NotImplementedError
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address is None
+
+    def get_position(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address.start, hp_address.numel
+
+    def get_group_index(self):
+        param2group = {}
+        for group_idx, bit16_group in enumerate(self.bit16_groups):
             for param in bit16_group:
                 param2group[param] = group_idx
         return param2group
-
-    def …
-
-… (15 removed lines; content not captured)
+
+    def get_param_index(self, lp_param, group_idx):
+        if not self.param2index:
+            for group in self.bit16_groups:
+                param2index = {}
+                for index, param in enumerate(group):
+                    param2index[param] = index
+                self.param2index.append(param2index)
+
+        return self.param2index[group_idx][lp_param]
+
+    def narrow_from_flatten(self, param, flatten_state):
+        if flatten_state is None:
+            return flatten_state
+        group_idx = self.param2group[param]
+        if self.param_not_in_partition(param, group_idx):
+            return None
+        start, numel = self.get_position(param, group_idx)
+        return flatten_state.narrow(0, start, numel)
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for group_idx, group in enumerate(self.bit16_groups):
+            for param in group:
+                self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+    def fetch_grad(self, monitor, params2name):
+        grad_dict = {}
+        for lp_param, name in params2name.items():
+            group_idx = self.param2group[lp_param]
+            param_id = self.get_param_index(lp_param, group_idx)
+            if self.param_not_in_partition(lp_param, group_idx):
                 continue
-… (15 removed lines; content not captured)
+            if self.stage == '1or2':
+                param_id = param_id - self.group_offset[group_idx] - 1
+            grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        return grad_dict
+
+
+class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '0'
+        self.bit16_groups = torch_opt.bf16_groups
+        self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition
+        self.param2group = self.get_group_index()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '1or2'
+        self.bit16_groups = torch_opt.bit16_groups
+        self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups
+        self.param2group = self.get_group_index()
+        self.group_offset = {}
+        self.get_group_offset()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        if getattr(self.torch_opt, "cpu_offload", False):
+            grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad
+            start, numel = self.get_position(lp_param, group_idx)
+            grad = grads.narrow(0, start, numel)
+        else:
+            grad = self.torch_opt.averaged_gradients[group_idx][param_id]
+        return grad
+
+    def get_group_offset(self):
+        for group_idx, group in enumerate(self.bit16_groups):
+            self.group_offset[group_idx] = -1
+            for lp_param in group:
+                if self.param_not_in_partition(lp_param, group_idx):
+                    self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
                 else:
-
-                continue
-            exp_avg = state['exp_avg'].narrow(0, start, numel)
-            exp_avg_sq = state['exp_avg_sq'].narrow(0, start, numel)
-            exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-            exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-            update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-            ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-            monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-            monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
+                    break
 
-… (19 removed lines; content not captured)
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        self.is_stage3 = True
-        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-    @staticmethod
-    def get_group_index(fp32_length, world_size, index):
-        for i in range(len(fp32_length) - 1):
-            if fp32_length[i] <= index < fp32_length[i + 1]:
-                interval_start = fp32_length[i]
-                interval_length = fp32_length[i + 1] - fp32_length[i]
-                sub_interval_length = interval_length // world_size
-                sub_index = (index - interval_start) // sub_interval_length
-                sub_interval_start = interval_start + sub_index * sub_interval_length
-                return sub_interval_start, min(sub_index, world_size - 1)
-        return fp32_length[-1], 0
-
-    def get_param_index(self, params2name, name2index, torch_opt):
-        padding = torch_opt.groups_padding
-        world_size = dist.get_world_size()
-        fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
-            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
-
-        bf16_groups = []
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
-            bf16_groups.extend(bf16_group)
-            for param in bf16_group:
-                param_length = len(param.flatten())
-                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
-                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
-                index += param_length
-                idx += 1
-        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
-        for _, name in params2name.items():
-            name_index = name2index[name]
-            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
-            need_padding = True if group_with_rank == world_size - 1 else False
-            new_start_idx = start_idx - group_index
-            new_end_idx = end_idx - group_index
-            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
-                    group_length - 1) == 0:
-                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
-            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DummyOptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '3'
+        self.bit16_groups = torch_opt.fp16_groups
+        self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
+        self.param2group = self.get_group_index()
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        """Each param partioned across all zero ranks"""
+        return False
+
+    def get_position(self, lp_param, group_idx):
+        param_id = self.torch_opt.get_param_id(lp_param)
+        return self.torch_opt.grad_position[param_id][1:]
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.averaged_gradients[group_idx][param_id]
 
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "FP32Optimizer": …
+        "FP32Optimizer": OptimizerMon,
         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
         "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
         "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
         "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "Adam": …
+        "Adam": OptimizerMon
     }
 
     @staticmethod
@@ -361,6 +332,7 @@ class OptimizerMonFactory:
         optimizer_class = optimizer.__class__.__name__
         if optimizer_class == "ChainedOptimizer":
             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+        logger.info(f'The optimizer type is {optimizer_class}')
 
-        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, …
-        return optimizer_mon_class()
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+        return optimizer_mon_class(optimizer)