mindstudio-probe 1.2.2__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/METADATA +4 -3
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/RECORD +243 -191
- msprobe/README.md +57 -21
- msprobe/core/__init__.py +17 -0
- msprobe/core/common/const.py +224 -82
- msprobe/core/common/decorator.py +50 -0
- msprobe/core/common/exceptions.py +5 -3
- msprobe/core/common/file_utils.py +274 -40
- msprobe/core/common/framework_adapter.py +169 -0
- msprobe/core/common/global_lock.py +86 -0
- msprobe/core/common/runtime.py +25 -0
- msprobe/core/common/utils.py +148 -72
- msprobe/core/common_config.py +7 -0
- msprobe/core/compare/acc_compare.py +640 -462
- msprobe/core/compare/check.py +36 -107
- msprobe/core/compare/compare_cli.py +4 -0
- msprobe/core/compare/config.py +72 -0
- msprobe/core/compare/highlight.py +217 -215
- msprobe/core/compare/layer_mapping/layer_mapping.py +4 -1
- msprobe/core/compare/merge_result/merge_result.py +12 -6
- msprobe/core/compare/multiprocessing_compute.py +227 -107
- msprobe/core/compare/npy_compare.py +32 -16
- msprobe/core/compare/utils.py +218 -244
- msprobe/{mindspore/runtime.py → core/config_check/__init__.py} +2 -4
- msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
- msprobe/core/config_check/checkers/base_checker.py +60 -0
- msprobe/core/config_check/checkers/dataset_checker.py +138 -0
- msprobe/core/config_check/checkers/env_args_checker.py +96 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
- msprobe/core/config_check/checkers/pip_checker.py +90 -0
- msprobe/core/config_check/checkers/random_checker.py +367 -0
- msprobe/core/config_check/checkers/weights_checker.py +147 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
- msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
- msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
- msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
- msprobe/core/config_check/config_check_cli.py +51 -0
- msprobe/core/config_check/config_checker.py +100 -0
- msprobe/{pytorch/parse.py → core/config_check/resource/dependency.yaml} +7 -4
- msprobe/core/config_check/resource/env.yaml +57 -0
- msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
- msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
- msprobe/core/config_check/utils/utils.py +107 -0
- msprobe/core/data_dump/api_registry.py +239 -0
- msprobe/core/data_dump/data_collector.py +36 -9
- msprobe/core/data_dump/data_processor/base.py +74 -53
- msprobe/core/data_dump/data_processor/mindspore_processor.py +119 -78
- msprobe/core/data_dump/data_processor/pytorch_processor.py +134 -96
- msprobe/core/data_dump/json_writer.py +146 -57
- msprobe/core/debugger/precision_debugger.py +143 -0
- msprobe/core/grad_probe/constant.py +2 -1
- msprobe/core/grad_probe/grad_compare.py +2 -2
- msprobe/core/grad_probe/utils.py +1 -1
- msprobe/core/hook_manager.py +242 -0
- msprobe/core/monitor/anomaly_processor.py +384 -0
- msprobe/core/overflow_check/abnormal_scene.py +2 -0
- msprobe/core/service.py +356 -0
- msprobe/core/single_save/__init__.py +0 -0
- msprobe/core/single_save/single_comparator.py +243 -0
- msprobe/core/single_save/single_saver.py +157 -0
- msprobe/docs/01.installation.md +6 -5
- msprobe/docs/02.config_introduction.md +89 -30
- msprobe/docs/03.config_examples.md +1 -0
- msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
- msprobe/docs/05.data_dump_PyTorch.md +184 -50
- msprobe/docs/06.data_dump_MindSpore.md +193 -28
- msprobe/docs/07.accuracy_checker_PyTorch.md +13 -3
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +72 -10
- msprobe/docs/09.accuracy_checker_MindSpore.md +19 -7
- msprobe/docs/10.accuracy_compare_PyTorch.md +266 -102
- msprobe/docs/11.accuracy_compare_MindSpore.md +117 -43
- msprobe/docs/12.overflow_check_PyTorch.md +5 -3
- msprobe/docs/13.overflow_check_MindSpore.md +6 -4
- msprobe/docs/14.data_parse_PyTorch.md +4 -10
- msprobe/docs/17.grad_probe.md +2 -1
- msprobe/docs/18.online_dispatch.md +3 -3
- msprobe/docs/19.monitor.md +211 -103
- msprobe/docs/21.visualization_PyTorch.md +100 -28
- msprobe/docs/22.visualization_MindSpore.md +103 -31
- msprobe/docs/23.generate_operator_PyTorch.md +9 -9
- msprobe/docs/25.tool_function_introduction.md +23 -22
- msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
- msprobe/docs/27.dump_json_instruction.md +278 -8
- msprobe/docs/28.debugger_save_instruction.md +111 -20
- msprobe/docs/28.kernel_dump_MindSpore.md +1 -1
- msprobe/docs/29.data_dump_MSAdapter.md +229 -0
- msprobe/docs/30.overflow_check_MSAdapter.md +31 -0
- msprobe/docs/31.config_check.md +95 -0
- msprobe/docs/32.ckpt_compare.md +69 -0
- msprobe/docs/33.generate_operator_MindSpore.md +190 -0
- msprobe/docs/34.RL_collect.md +92 -0
- msprobe/docs/35.nan_analyze.md +72 -0
- msprobe/docs/FAQ.md +3 -11
- msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
- msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
- msprobe/docs/img/compare_result.png +0 -0
- msprobe/docs/img/merge_result.png +0 -0
- msprobe/docs/img/save_compare_result_sample.png +0 -0
- msprobe/docs/img/visualization/proxy.png +0 -0
- msprobe/docs/img/visualization/vis_browser_1.png +0 -0
- msprobe/docs/img/visualization/vis_match_info.png +0 -0
- msprobe/docs/img/visualization/vis_precision_info.png +0 -0
- msprobe/docs/img/visualization/vis_search_info.png +0 -0
- msprobe/docs/img/visualization/vis_show_info.png +0 -0
- msprobe/docs/img/visualization/vis_showcase.png +0 -0
- msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
- msprobe/mindspore/__init__.py +3 -3
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +151 -55
- msprobe/mindspore/api_accuracy_checker/api_runner.py +25 -11
- msprobe/mindspore/api_accuracy_checker/base_compare_algorithm.py +2 -1
- msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +580 -0
- msprobe/mindspore/api_accuracy_checker/bench_functions/fusion_operator.py +41 -0
- msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
- msprobe/mindspore/api_accuracy_checker/data_manager.py +4 -3
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +9 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +451 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
- msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +11 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/cell_processor.py +204 -33
- msprobe/mindspore/code_mapping/graph_parser.py +4 -21
- msprobe/mindspore/common/const.py +73 -2
- msprobe/mindspore/common/utils.py +157 -29
- msprobe/mindspore/compare/common_dir_compare.py +382 -0
- msprobe/mindspore/compare/distributed_compare.py +2 -26
- msprobe/mindspore/compare/ms_compare.py +18 -398
- msprobe/mindspore/compare/ms_graph_compare.py +20 -10
- msprobe/mindspore/compare/utils.py +37 -0
- msprobe/mindspore/debugger/debugger_config.py +59 -7
- msprobe/mindspore/debugger/precision_debugger.py +83 -90
- msprobe/mindspore/dump/cell_dump_process.py +902 -0
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +889 -0
- msprobe/mindspore/dump/dump_tool_factory.py +18 -8
- msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
- msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
- msprobe/mindspore/dump/hook_cell/api_register.py +176 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +22 -12
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
- msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +42 -26
- msprobe/mindspore/dump/jit_dump.py +35 -27
- msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -16
- msprobe/mindspore/free_benchmark/api_pynative_self_check.py +22 -12
- msprobe/mindspore/free_benchmark/common/utils.py +1 -1
- msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +4 -2
- msprobe/mindspore/free_benchmark/self_check_tool_factory.py +6 -3
- msprobe/mindspore/grad_probe/global_context.py +9 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -1
- msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
- msprobe/mindspore/grad_probe/hook.py +2 -4
- msprobe/mindspore/mindspore_service.py +111 -0
- msprobe/mindspore/monitor/common_func.py +52 -0
- msprobe/mindspore/monitor/data_writers.py +237 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +1 -1
- msprobe/mindspore/monitor/features.py +13 -1
- msprobe/mindspore/monitor/module_hook.py +568 -444
- msprobe/mindspore/monitor/optimizer_collect.py +331 -0
- msprobe/mindspore/monitor/utils.py +71 -9
- msprobe/mindspore/ms_config.py +16 -15
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +5 -3
- msprobe/mindspore/task_handler_factory.py +5 -2
- msprobe/msprobe.py +19 -0
- msprobe/nan_analyze/__init__.py +14 -0
- msprobe/nan_analyze/analyzer.py +255 -0
- msprobe/nan_analyze/graph.py +189 -0
- msprobe/nan_analyze/utils.py +211 -0
- msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +3 -6
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +15 -13
- msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +206 -4
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +9 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +6 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +31 -9
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +28 -20
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +3 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +154 -0
- msprobe/pytorch/attl_manager.py +65 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +6 -0
- msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
- msprobe/pytorch/common/utils.py +53 -19
- msprobe/pytorch/compare/distributed_compare.py +4 -36
- msprobe/pytorch/compare/pt_compare.py +13 -84
- msprobe/pytorch/compare/utils.py +47 -0
- msprobe/pytorch/debugger/debugger_config.py +34 -17
- msprobe/pytorch/debugger/precision_debugger.py +50 -96
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
- msprobe/pytorch/dump/module_dump/module_dump.py +15 -61
- msprobe/pytorch/dump/module_dump/module_processer.py +150 -114
- msprobe/pytorch/free_benchmark/common/utils.py +1 -1
- msprobe/pytorch/free_benchmark/compare/single_benchmark.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +3 -3
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +1 -1
- msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +1 -1
- msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +1 -1
- msprobe/pytorch/function_factory.py +1 -1
- msprobe/pytorch/grad_probe/grad_monitor.py +2 -2
- msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
- msprobe/pytorch/hook_module/api_register.py +155 -0
- msprobe/pytorch/hook_module/hook_module.py +18 -22
- msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
- msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
- msprobe/pytorch/hook_module/register_optimizer_hook.py +2 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +193 -75
- msprobe/pytorch/hook_module/utils.py +28 -2
- msprobe/pytorch/monitor/csv2tb.py +14 -4
- msprobe/pytorch/monitor/data_writers.py +259 -0
- msprobe/pytorch/monitor/distributed/wrap_distributed.py +8 -2
- msprobe/pytorch/monitor/module_hook.py +336 -241
- msprobe/pytorch/monitor/module_metric.py +17 -0
- msprobe/pytorch/monitor/optimizer_collect.py +244 -224
- msprobe/pytorch/monitor/utils.py +84 -4
- msprobe/pytorch/online_dispatch/compare.py +0 -2
- msprobe/pytorch/online_dispatch/dispatch.py +13 -2
- msprobe/pytorch/online_dispatch/dump_compare.py +8 -2
- msprobe/pytorch/online_dispatch/utils.py +3 -0
- msprobe/pytorch/parse_tool/lib/interactive_cli.py +1 -6
- msprobe/pytorch/parse_tool/lib/utils.py +5 -4
- msprobe/pytorch/pt_config.py +16 -11
- msprobe/pytorch/pytorch_service.py +70 -0
- msprobe/visualization/builder/graph_builder.py +69 -10
- msprobe/visualization/builder/msprobe_adapter.py +24 -12
- msprobe/visualization/compare/graph_comparator.py +63 -51
- msprobe/visualization/compare/mode_adapter.py +22 -20
- msprobe/visualization/graph/base_node.py +11 -4
- msprobe/visualization/graph/distributed_analyzer.py +1 -10
- msprobe/visualization/graph/graph.py +2 -13
- msprobe/visualization/graph/node_op.py +1 -2
- msprobe/visualization/graph_service.py +251 -104
- msprobe/visualization/utils.py +26 -44
- msprobe/mindspore/dump/hook_cell/api_registry.py +0 -207
- msprobe/mindspore/dump/hook_cell/wrap_api.py +0 -212
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -140
- msprobe/mindspore/monitor/anomaly_detect.py +0 -404
- msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
- msprobe/mindspore/service.py +0 -543
- msprobe/pytorch/hook_module/api_registry.py +0 -166
- msprobe/pytorch/hook_module/wrap_distributed.py +0 -79
- msprobe/pytorch/hook_module/wrap_functional.py +0 -66
- msprobe/pytorch/hook_module/wrap_npu_custom.py +0 -85
- msprobe/pytorch/hook_module/wrap_tensor.py +0 -69
- msprobe/pytorch/hook_module/wrap_torch.py +0 -84
- msprobe/pytorch/hook_module/wrap_vf.py +0 -60
- msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
- msprobe/pytorch/monitor/anomaly_detect.py +0 -410
- msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
- msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
- msprobe/pytorch/service.py +0 -470
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.2.dist-info → mindstudio_probe-8.1.0.dist-info}/top_level.txt +0 -0
- /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
- /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
- /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
msprobe/pytorch/monitor/module_metric.py

```diff
@@ -16,6 +16,7 @@ import re
 
 import torch
 
+from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -143,6 +144,20 @@ class IdentMetric(Metric):
         return tensor
 
 
+@register_config_metric("shape")
+class ShapeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.shape
+
+
+@register_config_metric("dtype")
+class DtypeMetric(Metric):
+    @staticmethod
+    def get_metric_value(tensor, eps):
+        return tensor.dtype
+
+
 def get_metrics(ops, tag2tensor, eps, out_dict=None):
     """
     :param ops: ["op1", "op2"]
@@ -166,6 +181,8 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             # Non-tensor in/output filled with nan.
             out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
+        if is_float8_tensor(tensor):
+            tensor = tensor.float()
         for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
```
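The two metrics added above plug into the same decorator-driven registry the existing statistics use: `register_config_metric("shape")` records the metric under its name in `config_metric_registry`, and `get_metrics` later resolves it purely by string. The `is_float8_tensor` guard upcasts with `tensor.float()` before any statistic runs, presumably because reductions such as min/max/norm are not generally implemented for float8 dtypes. Below is a minimal, self-contained sketch of that registry pattern; the registry, `Metric` base class, and `__main__` usage are illustrative stand-ins, not the msprobe implementation.

```python
# Illustrative sketch of the decorator-registry pattern shown in the hunk above;
# the names below are stand-ins, not the actual msprobe classes.
config_metric_registry = {}


def register_config_metric(name):
    def decorator(cls):
        # Store an instance under its metric name so callers can resolve
        # "shape", "dtype", "norm", ... purely by string.
        config_metric_registry[name] = cls()
        return cls
    return decorator


class Metric:
    @staticmethod
    def get_metric_value(tensor, eps):
        raise NotImplementedError

    def get_metric(self, tensor, eps):
        return self.get_metric_value(tensor, eps)


@register_config_metric("shape")
class ShapeMetric(Metric):
    @staticmethod
    def get_metric_value(tensor, eps):
        return tensor.shape


if __name__ == "__main__":
    import torch
    t = torch.randn(2, 3)
    print(config_metric_registry["shape"].get_metric(t, eps=1e-8))  # torch.Size([2, 3])
```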
msprobe/pytorch/monitor/optimizer_collect.py

```diff
@@ -12,129 +12,120 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from collections import defaultdict
+from abc import abstractmethod
 
 import torch
-import torch.distributed as dist
 
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.core.common.const import MonitorConst
 
 
 class OptimizerMon(object):
-    def __init__(self) -> None:
+    def __init__(self, torch_opt) -> None:
         self.fp16_to_fp32_param = {}
-        self.
+        self.torch_opt = torch_opt
+        self.state = {}
+
+    def narrow_from_flatten(self, param, flatten_state):
+        return flatten_state
+
+    def get_state(self, torch_opt):
+        if hasattr(torch_opt, 'chained_optimizers'):
+            for opt in torch_opt.chained_optimizers:
+                self._get_single_state(opt)
+        else:
+            self._get_single_state(torch_opt)
 
-    def
-
+    def fetch_grad(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
 
-
-
-        exp_avg_sq_dict = defaultdict(float)
-        update_dict = defaultdict()
-        ratio_dict = defaultdict()
+        grad_dict = {}
+        first_param = True
         for param, name in params2name.items():
-            if
-
-
-
-
-
-
-            if
-
-
+            if monitor.duplicate_param.get(name, False):
+                continue
+            if self.fp16_to_fp32_param and param not in self.fp16_to_fp32_param:
+                continue
+            grad = param.main_grad if monitor.params_have_main_grad else param.grad
+            element_in_cur_partition = self.fp16_to_fp32_param.get(param, param).numel()
+            if param.numel() != element_in_cur_partition:
+                if first_param:
+                    grad = grad.flatten()[-element_in_cur_partition:]
+                else:  # supposed to be the last one
+                    grad = grad.flatten()[:element_in_cur_partition]
+            first_param = False
+
+            if grad is None:
+                if not monitor.fsdp_wrapped_module:
+                    logger.warning(f"grad is None: {name}, maybe something wrong happened.")
+                continue
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+        return grad_dict
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        pass
+
+    def fetch_mv(self, monitor, params2name):
+        if not self.fp16_to_fp32_param:
+            self.map_fp16_to_fp32_param(self.torch_opt)
+        if not self.state:
+            self.get_state(self.torch_opt)
+
+        exp_avg_dict = {}
+        exp_avg_sq_dict = {}
+        update_dict = {}
+        ratio_dict = {}
+
+        if not self.state:
+            logger.warning('optimizer state can not accessed')
+            return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
+
+        for lp_param, name in params2name.items():
+            if lp_param in self.fp16_to_fp32_param:
+                hp_param = self.fp16_to_fp32_param[lp_param]
+            else:
+                hp_param = lp_param
+
+            if hp_param in self.state:
+                state_param = self.state.get(hp_param, {})
+                exp_avg = self.narrow_from_flatten(lp_param, state_param.get("exp_avg", None))
+                exp_avg_sq = self.narrow_from_flatten(lp_param, state_param.get("exp_avg_sq", None))
                 if monitor.mv_distribution:
                     exp_avg_dict[name] = exp_avg
                     exp_avg_sq_dict[name] = exp_avg_sq
                 if monitor.mg_direction:
                     exp_avg_dict[name] = exp_avg
                 if monitor.ur_distribution:
-                    if len(torch_opt.param_groups) > 1:
-                        logger.info(f"the length of torch_opt.param_groups is {len(torch_opt.param_groups)}.")
+                    if len(self.torch_opt.param_groups) > 1:
+                        logger.info(f"the length of torch_opt.param_groups is {len(self.torch_opt.param_groups)}.")
                     if 'step' in state_param:
                         step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                    elif 'step' in torch_opt.param_groups[0]:
-                        step = torch_opt.param_groups[0]['step']  # AdamW from mindspeed
+                    elif 'step' in self.torch_opt.param_groups[0]:
+                        step = self.torch_opt.param_groups[0]['step']  # AdamW from mindspeed
                     else:
                         logger.warning(f"step of {name} is None, maybe something wrong happened.")
                         continue
-                    exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                    exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
+                    exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
+                    exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
+                    update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
                     ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
                     monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
                     monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
         return MVResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict)
-
-    def
-
-
-
-
-
-
-
-
-        def get_flatten_grad(self, optimizer, group_idx):
-            if fp32_partitioned_groups_flat[group_idx].grad is None:
-                if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
-                        optimizer.averaged_gradients[group_idx],
-                        int(optimizer.partition_size[group_idx])
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                else:
-                    fp32_partitioned_groups_flat_grad = optimizer.flatten(
-                        optimizer.averaged_gradients[group_idx]
-                    ).to(fp32_partitioned_groups_flat[group_idx].dtype)
-                return fp32_partitioned_groups_flat_grad
-            else:
-                return fp32_partitioned_groups_flat[group_idx].grad
-
-        for group_idx in range(len(fp32_partitioned_groups_flat)):
-            fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)
-
-        for name in params2name.values():
-            start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
-            if group_with_rank != partition_id and isinstance(group_with_rank, int):
-                continue
-            fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
-            fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
-            param2name[fp32_param] = name
-            if not torch_opt.state:
-                continue
-            state_param = list(torch_opt.state.values())[group_idx]
-            exp_avg = state_param.get("exp_avg", None)
-            exp_avg_sq = state_param.get("exp_avg_sq", None)
-            if exp_avg is None or exp_avg_sq is None:
-                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.")
-                continue
-            exp_avg = exp_avg[start_idx: end_idx]
-            exp_avg_sq = exp_avg_sq[start_idx: end_idx]
-            if monitor.mv_distribution:
-                exp_avg_dict[name] = exp_avg
-                exp_avg_sq_dict[name] = exp_avg_sq
-            if monitor.mg_direction:
-                exp_avg_dict[name] = exp_avg
-            if monitor.ur_distribution:
-                if 'step' in state_param:
-                    step = state_param['step']  # Optimizer from pytorch or FusedAdam from apex(used by megatron)
-                elif 'step' in torch_opt.param_groups[group_idx]:
-                    step = torch_opt.param_groups[group_idx]['step']  # AdamW from mindspeed
-                else:
-                    logger.warning(f"step of {name} is None, maybe something wrong happened.")
-                    continue
-                exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step)
-                exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step)
-                update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + torch_opt.defaults['eps'])
-                ratio_dict[name] = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
-                monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name])
-                monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name])
-        del fp32_partitioned_groups_flat_grad
-        return MVGradResult(exp_avg=exp_avg_dict, exp_avg_sq=exp_avg_sq_dict, update=update_dict, ratio=ratio_dict,
-                            grad=param2name)
+
+    def _get_single_state(self, torch_opt):
+        state = {}
+        if hasattr(torch_opt, 'param_to_cpu_states_map'):
+            state = torch_opt.param_to_cpu_states_map
+        elif hasattr(torch_opt, 'state'):
+            state = torch_opt.state
+        elif hasattr(torch_opt, 'optimizer') and hasattr(torch_opt.optimizer, 'state'):
+            state = torch_opt.optimizer.state
+        self.state.update(state)
 
 
 class MixPrecisionOptimizerMon(OptimizerMon):
```
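The rewritten `OptimizerMon` base class in optimizer_collect.py binds the optimizer at construction time (`self.torch_opt`) and splits monitoring into `fetch_grad`, which collects per-parameter gradients keyed by `monitor.name2tag`, and `fetch_mv`, which reads the first/second moments from the gathered `self.state` and derives the bias-corrected Adam update. As a reference for what `fetch_mv` computes, here is a small self-contained sketch against a plain `torch.optim.Adam`; the loop and variable names are illustrative, only the state keys (`exp_avg`, `exp_avg_sq`, `step`) and `defaults` come from PyTorch itself.

```python
# Minimal sketch of what fetch_mv reads from a plain torch.optim.Adam:
# exp_avg / exp_avg_sq live in optimizer.state, betas / eps in optimizer.defaults.
# Variable names here are illustrative, not msprobe's.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model(torch.randn(8, 4)).sum().backward()
opt.step()

for param, name in zip(model.parameters(), ["weight", "bias"]):
    state = opt.state[param]
    exp_avg, exp_avg_sq, step = state["exp_avg"], state["exp_avg_sq"], state["step"]
    beta1, beta2 = opt.defaults["betas"]
    eps = opt.defaults["eps"]
    # Bias-corrected moments and the update direction reported as "update"/"ratio".
    exp_avg_hat = exp_avg / (1 - beta1 ** step)
    exp_avg_sq_hat = exp_avg_sq / (1 - beta2 ** step)
    update = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + eps)
    ratio = exp_avg_hat / torch.sqrt(exp_avg_sq_hat)
    print(name, update.norm().item(), ratio.norm().item())
```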
```diff
@@ -142,21 +133,14 @@ class MixPrecisionOptimizerMon(OptimizerMon):
     混合精度优化器监控类。在混合精度训练中监控和管理优化器。
     混合精度训练通过适当降低某些计算的精度来加速训练过程并减少内存消耗。
     """
-
-    def map_fp16_tp_fp32_param(self, torch_opt):
+    def map_fp16_to_fp32_param(self, torch_opt):
         for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
             for fp16_param, fp32_param in zip(fp16_group, fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
-
 
 class MegatronDistributedOptimizerMon(OptimizerMon):
-    def
+    def map_fp16_to_fp32_param(self, torch_opt):
         if not (hasattr(torch_opt, "model_float16_groups") and
                 hasattr(torch_opt, "shard_fp32_from_float16_groups")):
             raise Exception(
```
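With the typo fixed (`map_fp16_tp_fp32_param` → `map_fp16_to_fp32_param`), the mixed-precision subclasses now only provide the fp16 → fp32 master-weight mapping; the shared `fetch_mv`/`fetch_grad` in the base class consume it. A hedged sketch of what that mapping amounts to, with made-up group lists standing in for Megatron's real `float16_groups`/`fp32_from_float16_groups` (model-precision and master copies, respectively):

```python
# Illustrative sketch of the fp16 -> fp32 master-weight mapping that
# map_fp16_to_fp32_param builds; the group structures below are made up.
import torch

float16_groups = [[torch.zeros(3, dtype=torch.float16), torch.zeros(2, dtype=torch.float16)]]
fp32_from_float16_groups = [[p.detach().float() for p in group] for group in float16_groups]

fp16_to_fp32_param = {}
for fp16_group, fp32_group in zip(float16_groups, fp32_from_float16_groups):
    for fp16_param, fp32_param in zip(fp16_group, fp32_group):
        fp16_to_fp32_param[fp16_param] = fp32_param

# Later, monitoring code resolves the high-precision copy for any model parameter:
some_fp16_param = float16_groups[0][0]
hp_param = fp16_to_fp32_param.get(some_fp16_param, some_fp16_param)
print(hp_param.dtype)  # torch.float32
```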
```diff
@@ -167,141 +151,176 @@ class MegatronDistributedOptimizerMon(OptimizerMon):
             for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
                 self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
 
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            self.map_fp16_tp_fp32_param(torch_opt)
-
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
+class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
-class MegatronFP32OptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
 
+class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for opt in torch_opt.chained_optimizers:
+            super().map_fp16_to_fp32_param(opt)
 
-class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        if not self.fp16_to_fp32_param and torch_opt is not None:
-            for opt in torch_opt.chained_optimizers:
-                self.map_fp16_tp_fp32_param(opt)
 
-
-
-
-
-
+class DeepSpeedZeroOptimizerMon(OptimizerMon):
+    """
+    Base monitor class for DeepSpeed ZeRO optimizer.
+    ZeRO stage 0 no partition
+    ZeRO stage 1 partitions optimizer states across data parallel processes.
+    ZeRO stage 2 additionally partitions gradients.
+    ZeRO stage 3 additionally partitions parameters.
+
+    This class provides monitoring capabilities for ZeRO optimizers by:
+    - Handling gradient collection for different ZeRO stages
+    - Managing optimizer state access for monitoring
+    """
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = ''
+        self.bit16_groups = []
+        self.fp32_flat_groups = []
+        self.param2group = ()
+        self.param2index = []
+        self.group_offset = {}
+
+    @abstractmethod
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        raise NotImplementedError
+
+    def param_not_in_partition(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address is None
+
+    def get_position(self, lp_param, group_idx):
+        param_slice_mapping = self.torch_opt.state_dict()['param_slice_mappings'][group_idx]
+        hp_address = param_slice_mapping.get(self.torch_opt.param_names.get(lp_param))
+        return hp_address.start, hp_address.numel
+
+    def get_group_index(self):
+        param2group = {}
+        for group_idx, bit16_group in enumerate(self.bit16_groups):
+            for param in bit16_group:
+                param2group[param] = group_idx
+        return param2group
+
+    def get_param_index(self, lp_param, group_idx):
+        if not self.param2index:
+            for group in self.bit16_groups:
+                param2index = {}
+                for index, param in enumerate(group):
+                    param2index[param] = index
+                self.param2index.append(param2index)
+
+        return self.param2index[group_idx][lp_param]
+
+    def narrow_from_flatten(self, param, flatten_state):
+        if flatten_state is None:
+            return flatten_state
+        group_idx = self.param2group[param]
+        if self.param_not_in_partition(param, group_idx):
+            return None
+        start, numel = self.get_position(param, group_idx)
+        return flatten_state.narrow(0, start, numel)
+
+    def map_fp16_to_fp32_param(self, torch_opt):
+        for group_idx, group in enumerate(self.bit16_groups):
+            for param in group:
+                self.fp16_to_fp32_param[param] = self.fp32_flat_groups[group_idx]
+
+    def fetch_grad(self, monitor, params2name):
+        grad_dict = {}
+        for lp_param, name in params2name.items():
+            group_idx = self.param2group[lp_param]
+            param_id = self.get_param_index(lp_param, group_idx)
+            if self.param_not_in_partition(lp_param, group_idx):
+                continue
+            if self.stage == '1or2':
+                param_id = param_id - self.group_offset[group_idx] - 1
+            grad = self.get_grad_for_param(lp_param, group_idx, param_id)
+            tag = monitor.name2tag.get(name, {}).get(MonitorConst.POST_GRAD)
+            monitor.register_param_call_id("hook_optimizer", tag)
+            grad_dict[tag] = grad
+
+        return grad_dict
+
+
+class DeepSpeedZeroOptimizerStage0Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '0'
+        self.bit16_groups = torch_opt.bf16_groups
+        self.fp32_flat_groups = torch_opt.fp32_groups_flat_partition
+        self.param2group = self.get_group_index()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.fp32_groups_gradient_dict[group_idx][param_id]
+
+
+class DeepSpeedZeroOptimizerStage1or2Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '1or2'
+        self.bit16_groups = torch_opt.bit16_groups
+        self.fp32_flat_groups = torch_opt.single_partition_of_fp32_groups
+        self.param2group = self.get_group_index()
+        self.group_offset = {}
+        self.get_group_offset()
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        if getattr(self.torch_opt, "cpu_offload", False):
+            grads = self.torch_opt.single_partition_of_fp32_groups[group_idx].grad
+            start, numel = self.get_position(lp_param, group_idx)
+            grad = grads.narrow(0, start, numel)
+        else:
+            grad = self.torch_opt.averaged_gradients[group_idx][param_id]
+        return grad
+
+    def get_group_offset(self):
+        for group_idx, group in enumerate(self.bit16_groups):
+            self.group_offset[group_idx] = -1
+            for lp_param in group:
+                if self.param_not_in_partition(lp_param, group_idx):
+                    self.group_offset[group_idx] = self.get_param_index(lp_param, group_idx)
+                else:
+                    break
 
 
-class
-    def
-
-
-
+class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
+    def __init__(self, torch_opt):
+        super().__init__(torch_opt)
+        self.stage = '3'
+        self.bit16_groups = torch_opt.fp16_groups
+        self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
+        self.param2group = self.get_group_index()
 
-
-
-
-
-
-
-
-
-    def
-        return self.
-
-
-class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
-    def get_param_index(self, params2name, name2index, torch_opt):
-        fp16_groups = torch_opt.fp16_partitioned_groups
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, fp16_group in enumerate(fp16_groups):
-            for param in fp16_group:
-                param_length = len(param.flatten())
-                index_length[idx] = (index, index + param_length, group_idx)
-                index += param_length
-                idx += 1
-        for _, name in params2name.items():
-            idx = name2index[name]
-            start_idx, end_idx, group_idx = index_length[idx]
-            name2indices[name] = (start_idx, end_idx, group_idx, None)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        self.is_stage3 = True
-        fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-    @staticmethod
-    def get_group_index(fp32_length, world_size, index):
-        for i in range(len(fp32_length) - 1):
-            if fp32_length[i] <= index < fp32_length[i + 1]:
-                interval_start = fp32_length[i]
-                interval_length = fp32_length[i + 1] - fp32_length[i]
-                sub_interval_length = interval_length // world_size
-                sub_index = (index - interval_start) // sub_interval_length
-                sub_interval_start = interval_start + sub_index * sub_interval_length
-                return sub_interval_start, min(sub_index, world_size - 1)
-        return fp32_length[-1], 0
-
-    def get_param_index(self, params2name, name2index, torch_opt):
-        padding = torch_opt.groups_padding
-        world_size = dist.get_world_size()
-        fp32_length = [0]
-        for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
-            fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])
-
-        bf16_groups = []
-        name2indices = defaultdict()
-        index_length = defaultdict()
-        index = 0
-        idx = 0
-        for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
-            bf16_groups.extend(bf16_group)
-            for param in bf16_group:
-                param_length = len(param.flatten())
-                group_index, group_with_rank = self.get_group_index(fp32_length, world_size, index)
-                index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
-                index += param_length
-                idx += 1
-        group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
-        for _, name in params2name.items():
-            name_index = name2index[name]
-            start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
-            need_padding = True if group_with_rank == world_size - 1 else False
-            new_start_idx = start_idx - group_index
-            new_end_idx = end_idx - group_index
-            if need_padding and group_length - 1 <= name_index <= len(bf16_groups) - 1 and name_index % (
-                    group_length - 1) == 0:
-                new_end_idx -= padding[int(name_index // (group_length - 1) - 1)]
-            name2indices[name] = (new_start_idx, new_end_idx, group_idx, group_with_rank)
-        return name2indices
-
-    def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
-        fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
-        return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)
-
-
-class DummyOptimizerMon(OptimizerMon):
-    def fetch_mv(self, monitor, torch_opt, params2name):
-        return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+    def param_not_in_partition(self, param, group_index):
+        """Each param partioned across all zero ranks"""
+        return False
+
+    def get_position(self, lp_param, group_idx):
+        param_id = self.torch_opt.get_param_id(lp_param)
+        return self.torch_opt.grad_position[param_id][1:]
+
+    def get_grad_for_param(self, lp_param, group_idx, param_id):
+        return self.torch_opt.averaged_gradients[group_idx][param_id]
 
 
 class OptimizerMonFactory:
     _optimizer_mon_map = {
-        "FP32Optimizer":
+        "FP32Optimizer": OptimizerMon,
         "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
         "DistributedOptimizer": MegatronDistributedOptimizerMon,
+        "SwapDistributedOptimizer": MegatronDistributedOptimizerMon,
         "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+        "ChainedSwapDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
         "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
         "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
         "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
         "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
-        "Adam":
+        "Adam": OptimizerMon
     }
 
     @staticmethod
```
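The common trick across the ZeRO monitors above is `narrow_from_flatten`: optimizer state and master weights are kept as one flat fp32 buffer per group, and the per-parameter view is recovered with `Tensor.narrow(0, start, numel)` using offsets taken from DeepSpeed's `param_slice_mappings` (stages 0/1/2) or `grad_position` (stage 3). A tiny sketch of just the slicing step, with hard-coded offsets standing in for those lookups:

```python
# Sketch of the narrow_from_flatten idea: per-parameter state is a slice of a
# flat fp32 partition; the (start, numel) offsets below are hard-coded for illustration.
import torch

flat_exp_avg = torch.arange(12, dtype=torch.float32)   # flattened state for one group
param_offsets = {"layer.weight": (0, 8), "layer.bias": (8, 4)}

def narrow_from_flatten(name, flat_state):
    start, numel = param_offsets[name]
    # narrow(dim, start, length) returns a view of the flat buffer, no copy
    return flat_state.narrow(0, start, numel)

print(narrow_from_flatten("layer.bias", flat_exp_avg))  # tensor([ 8.,  9., 10., 11.])
```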
```diff
@@ -310,6 +329,7 @@ class OptimizerMonFactory:
         optimizer_class = optimizer.__class__.__name__
         if optimizer_class == "ChainedOptimizer":
             optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+        logger.info(f'The optimizer type is {optimizer_class}')
 
-        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class,
-        return optimizer_mon_class()
+        optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, OptimizerMon)
+        return optimizer_mon_class(optimizer)
```
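`OptimizerMonFactory` now dispatches purely on the optimizer's class name, unwraps Megatron's `ChainedOptimizer` by looking at its first inner optimizer, and falls back to the base `OptimizerMon` (the separate `DummyOptimizerMon` class is removed). A hedged sketch of that dispatch shape with a vanilla PyTorch optimizer; the map and monitor classes here are stand-ins, only the class-name lookup and fallback are the point:

```python
# Illustrative dispatch in the style of OptimizerMonFactory; the map and classes
# below are stand-ins for the real monitor hierarchy.
import torch

class BaseMon:
    def __init__(self, torch_opt):
        self.torch_opt = torch_opt

class AdamMon(BaseMon):
    pass

_mon_map = {"Adam": AdamMon}

def create_mon(optimizer):
    optimizer_class = optimizer.__class__.__name__
    if optimizer_class == "ChainedOptimizer":  # Megatron wrapper: inspect the first inner optimizer
        optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
    mon_class = _mon_map.get(optimizer_class, BaseMon)   # fall back to the base monitor
    return mon_class(optimizer)

opt = torch.optim.Adam(torch.nn.Linear(2, 2).parameters())
print(type(create_mon(opt)).__name__)  # AdamMon
```

Falling back to the base monitor rather than a dummy keeps gradient and moment collection usable for any optimizer that exposes a standard `state` and `param_groups`.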
|