mindstudio-probe 1.3.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (213) hide show
  1. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +4 -2
  2. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +204 -152
  3. msprobe/README.md +32 -1
  4. msprobe/core/__init__.py +17 -0
  5. msprobe/core/common/const.py +120 -21
  6. msprobe/core/common/exceptions.py +2 -2
  7. msprobe/core/common/file_utils.py +279 -50
  8. msprobe/core/common/framework_adapter.py +169 -0
  9. msprobe/core/common/global_lock.py +86 -0
  10. msprobe/core/common/runtime.py +25 -0
  11. msprobe/core/common/utils.py +136 -45
  12. msprobe/core/common_config.py +7 -0
  13. msprobe/core/compare/acc_compare.py +646 -428
  14. msprobe/core/compare/check.py +36 -103
  15. msprobe/core/compare/compare_cli.py +4 -0
  16. msprobe/core/compare/config.py +72 -0
  17. msprobe/core/compare/highlight.py +215 -215
  18. msprobe/core/compare/layer_mapping/layer_mapping.py +2 -0
  19. msprobe/core/compare/merge_result/merge_result.py +4 -4
  20. msprobe/core/compare/multiprocessing_compute.py +223 -110
  21. msprobe/core/compare/npy_compare.py +2 -4
  22. msprobe/core/compare/utils.py +214 -244
  23. msprobe/core/config_check/__init__.py +17 -0
  24. msprobe/{pytorch/dump/kernel_dump/kernel_config.py → core/config_check/checkers/__init__.py} +8 -16
  25. msprobe/core/config_check/checkers/base_checker.py +60 -0
  26. msprobe/core/config_check/checkers/dataset_checker.py +138 -0
  27. msprobe/core/config_check/checkers/env_args_checker.py +96 -0
  28. msprobe/core/config_check/checkers/hyperparameter_checker.py +170 -0
  29. msprobe/core/config_check/checkers/pip_checker.py +90 -0
  30. msprobe/core/config_check/checkers/random_checker.py +367 -0
  31. msprobe/core/config_check/checkers/weights_checker.py +147 -0
  32. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +74 -0
  33. msprobe/core/config_check/ckpt_compare/megatron_loader.py +302 -0
  34. msprobe/core/config_check/ckpt_compare/metrics.py +83 -0
  35. msprobe/core/config_check/ckpt_compare/name_mapping.yaml +12 -0
  36. msprobe/core/config_check/config_check_cli.py +51 -0
  37. msprobe/core/config_check/config_checker.py +100 -0
  38. msprobe/{mindspore/runtime.py → core/config_check/resource/dependency.yaml} +7 -4
  39. msprobe/core/config_check/resource/env.yaml +57 -0
  40. msprobe/core/config_check/resource/hyperparameter.yaml +21 -0
  41. msprobe/core/config_check/utils/hyperparameter_parser.py +115 -0
  42. msprobe/core/config_check/utils/utils.py +107 -0
  43. msprobe/core/data_dump/api_registry.py +67 -4
  44. msprobe/core/data_dump/data_collector.py +170 -89
  45. msprobe/core/data_dump/data_processor/base.py +72 -51
  46. msprobe/core/data_dump/data_processor/mindspore_processor.py +109 -55
  47. msprobe/core/data_dump/data_processor/pytorch_processor.py +90 -82
  48. msprobe/core/data_dump/json_writer.py +143 -27
  49. msprobe/core/debugger/precision_debugger.py +144 -0
  50. msprobe/core/grad_probe/constant.py +1 -1
  51. msprobe/core/grad_probe/grad_compare.py +1 -1
  52. msprobe/core/grad_probe/utils.py +1 -1
  53. msprobe/core/hook_manager.py +242 -0
  54. msprobe/core/monitor/anomaly_processor.py +384 -0
  55. msprobe/core/service.py +357 -0
  56. msprobe/core/single_save/__init__.py +0 -0
  57. msprobe/core/single_save/single_comparator.py +243 -0
  58. msprobe/core/single_save/single_saver.py +146 -0
  59. msprobe/docs/01.installation.md +6 -5
  60. msprobe/docs/02.config_introduction.md +79 -22
  61. msprobe/docs/03.config_examples.md +1 -0
  62. msprobe/docs/04.kernel_dump_PyTorch.md +1 -1
  63. msprobe/docs/05.data_dump_PyTorch.md +118 -49
  64. msprobe/docs/06.data_dump_MindSpore.md +167 -20
  65. msprobe/docs/07.accuracy_checker_PyTorch.md +2 -2
  66. msprobe/docs/08.accuracy_checker_online_PyTorch.md +69 -9
  67. msprobe/docs/09.accuracy_checker_MindSpore.md +18 -6
  68. msprobe/docs/10.accuracy_compare_PyTorch.md +212 -74
  69. msprobe/docs/11.accuracy_compare_MindSpore.md +87 -37
  70. msprobe/docs/12.overflow_check_PyTorch.md +2 -2
  71. msprobe/docs/13.overflow_check_MindSpore.md +2 -2
  72. msprobe/docs/14.data_parse_PyTorch.md +3 -3
  73. msprobe/docs/17.grad_probe.md +2 -1
  74. msprobe/docs/18.online_dispatch.md +2 -2
  75. msprobe/docs/19.monitor.md +90 -44
  76. msprobe/docs/21.visualization_PyTorch.md +68 -15
  77. msprobe/docs/22.visualization_MindSpore.md +71 -18
  78. msprobe/docs/25.tool_function_introduction.md +23 -22
  79. msprobe/docs/26.data_dump_PyTorch_baseline.md +14 -3
  80. msprobe/docs/27.dump_json_instruction.md +1 -1
  81. msprobe/docs/28.debugger_save_instruction.md +111 -20
  82. msprobe/docs/29.data_dump_MSAdapter.md +2 -2
  83. msprobe/docs/30.overflow_check_MSAdapter.md +2 -2
  84. msprobe/docs/31.config_check.md +95 -0
  85. msprobe/docs/32.ckpt_compare.md +69 -0
  86. msprobe/docs/33.generate_operator_MindSpore.md +181 -0
  87. msprobe/docs/34.RL_collect.md +92 -0
  88. msprobe/docs/35.nan_analyze.md +72 -0
  89. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +12 -1
  90. msprobe/docs/data_dump_MindSpore/dynamic_graph_quick_start_example.md +3 -1
  91. msprobe/docs/img/compare_result.png +0 -0
  92. msprobe/docs/img/save_compare_result_sample.png +0 -0
  93. msprobe/docs/img/visualization/proxy.png +0 -0
  94. msprobe/mindspore/__init__.py +1 -2
  95. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +150 -58
  96. msprobe/mindspore/api_accuracy_checker/api_runner.py +7 -3
  97. msprobe/mindspore/api_accuracy_checker/bench_functions/flash_attention_score.py +47 -69
  98. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +4 -0
  99. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  100. msprobe/mindspore/api_accuracy_checker/data_manager.py +2 -2
  101. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +460 -0
  102. msprobe/mindspore/api_accuracy_checker/generate_op_script/operator_replication.template +2081 -0
  103. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +9 -0
  104. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  105. msprobe/mindspore/cell_processor.py +204 -33
  106. msprobe/mindspore/code_mapping/graph_parser.py +4 -21
  107. msprobe/mindspore/common/const.py +17 -7
  108. msprobe/mindspore/common/utils.py +128 -11
  109. msprobe/mindspore/compare/common_dir_compare.py +382 -0
  110. msprobe/mindspore/compare/distributed_compare.py +2 -26
  111. msprobe/mindspore/compare/ms_compare.py +17 -405
  112. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  113. msprobe/mindspore/compare/utils.py +37 -0
  114. msprobe/mindspore/debugger/debugger_config.py +53 -3
  115. msprobe/mindspore/debugger/precision_debugger.py +72 -91
  116. msprobe/mindspore/dump/cell_dump_process.py +877 -0
  117. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +864 -0
  118. msprobe/mindspore/dump/dump_tool_factory.py +13 -5
  119. msprobe/mindspore/dump/graph_mode_cell_dump.py +139 -0
  120. msprobe/mindspore/dump/graph_tensor_dump.py +123 -0
  121. msprobe/mindspore/dump/hook_cell/api_register.py +40 -6
  122. msprobe/mindspore/dump/hook_cell/hook_cell.py +18 -7
  123. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +88 -0
  124. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +8 -2
  125. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +18 -0
  126. msprobe/mindspore/dump/jit_dump.py +21 -18
  127. msprobe/mindspore/dump/kernel_kbyk_dump.py +6 -3
  128. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +110 -0
  129. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +15 -15
  130. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +12 -6
  131. msprobe/mindspore/free_benchmark/common/utils.py +1 -1
  132. msprobe/mindspore/grad_probe/global_context.py +7 -2
  133. msprobe/mindspore/grad_probe/grad_stat_csv.py +3 -2
  134. msprobe/mindspore/mindspore_service.py +114 -0
  135. msprobe/mindspore/monitor/common_func.py +52 -0
  136. msprobe/mindspore/monitor/data_writers.py +237 -0
  137. msprobe/mindspore/monitor/features.py +20 -7
  138. msprobe/mindspore/monitor/module_hook.py +281 -209
  139. msprobe/mindspore/monitor/optimizer_collect.py +334 -0
  140. msprobe/mindspore/monitor/utils.py +25 -5
  141. msprobe/mindspore/ms_config.py +16 -15
  142. msprobe/mindspore/task_handler_factory.py +5 -2
  143. msprobe/msprobe.py +19 -0
  144. msprobe/nan_analyze/__init__.py +14 -0
  145. msprobe/nan_analyze/analyzer.py +255 -0
  146. msprobe/nan_analyze/graph.py +189 -0
  147. msprobe/nan_analyze/utils.py +211 -0
  148. msprobe/pytorch/api_accuracy_checker/common/config.py +2 -2
  149. msprobe/pytorch/api_accuracy_checker/compare/compare.py +36 -34
  150. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  151. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +4 -7
  152. msprobe/pytorch/api_accuracy_checker/generate_op_script/operator_replication.template +204 -2
  153. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +12 -11
  154. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +1 -0
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +8 -5
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +2 -3
  157. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +29 -13
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +12 -2
  159. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +45 -31
  160. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +156 -0
  161. msprobe/pytorch/attl_manager.py +65 -0
  162. msprobe/pytorch/bench_functions/npu_fusion_attention.py +27 -0
  163. msprobe/pytorch/common/utils.py +26 -14
  164. msprobe/pytorch/compare/distributed_compare.py +4 -36
  165. msprobe/pytorch/compare/pt_compare.py +13 -84
  166. msprobe/pytorch/compare/utils.py +47 -0
  167. msprobe/pytorch/debugger/debugger_config.py +34 -17
  168. msprobe/pytorch/debugger/precision_debugger.py +66 -118
  169. msprobe/pytorch/dump/module_dump/hook_wrapper.py +93 -0
  170. msprobe/pytorch/dump/module_dump/module_dump.py +11 -58
  171. msprobe/pytorch/dump/module_dump/module_processer.py +143 -113
  172. msprobe/pytorch/grad_probe/grad_stat_csv.py +3 -2
  173. msprobe/pytorch/hook_module/api_register.py +29 -5
  174. msprobe/pytorch/hook_module/hook_module.py +9 -18
  175. msprobe/pytorch/hook_module/jit_script_wrapper.py +33 -0
  176. msprobe/pytorch/hook_module/pt_hook_manager.py +68 -0
  177. msprobe/pytorch/hook_module/support_wrap_ops.yaml +22 -1
  178. msprobe/pytorch/hook_module/utils.py +28 -2
  179. msprobe/pytorch/monitor/csv2tb.py +6 -2
  180. msprobe/pytorch/monitor/data_writers.py +259 -0
  181. msprobe/pytorch/monitor/module_hook.py +227 -158
  182. msprobe/pytorch/monitor/module_metric.py +14 -0
  183. msprobe/pytorch/monitor/optimizer_collect.py +242 -270
  184. msprobe/pytorch/monitor/utils.py +16 -3
  185. msprobe/pytorch/online_dispatch/dispatch.py +4 -2
  186. msprobe/pytorch/online_dispatch/dump_compare.py +5 -2
  187. msprobe/pytorch/parse_tool/lib/utils.py +3 -3
  188. msprobe/pytorch/pt_config.py +8 -7
  189. msprobe/pytorch/pytorch_service.py +73 -0
  190. msprobe/visualization/builder/graph_builder.py +33 -13
  191. msprobe/visualization/builder/msprobe_adapter.py +24 -11
  192. msprobe/visualization/compare/graph_comparator.py +53 -45
  193. msprobe/visualization/compare/mode_adapter.py +31 -1
  194. msprobe/visualization/graph/base_node.py +3 -3
  195. msprobe/visualization/graph/graph.py +2 -2
  196. msprobe/visualization/graph_service.py +250 -103
  197. msprobe/visualization/utils.py +27 -11
  198. msprobe/mindspore/dym_loader/hook_dynamic_loader.cc +0 -106
  199. msprobe/mindspore/monitor/anomaly_detect.py +0 -404
  200. msprobe/mindspore/monitor/module_spec_verifier.py +0 -94
  201. msprobe/mindspore/service.py +0 -549
  202. msprobe/pytorch/monitor/anomaly_analyse.py +0 -201
  203. msprobe/pytorch/monitor/anomaly_detect.py +0 -410
  204. msprobe/pytorch/monitor/module_spec_verifier.py +0 -95
  205. msprobe/pytorch/monitor/unittest/test_monitor.py +0 -160
  206. msprobe/pytorch/service.py +0 -473
  207. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  208. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  209. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  210. {mindstudio_probe-1.3.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
  211. /msprobe/{mindspore → core}/compare/ms_to_pt_api.yaml +0 -0
  212. /msprobe/{mindspore/dump → core}/kernel_dump/kernel_config.py +0 -0
  213. /msprobe/{pytorch/monitor/unittest → core/monitor}/__init__.py +0 -0
@@ -1,95 +0,0 @@
1
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import re
17
- import abc
18
- import torch
19
-
20
- from msprobe.pytorch.common.log import logger
21
-
22
# Registry of every ConfigValidator implementation, keyed by class name.
config_validator_registry = {}


def register_config_validator(cls):
    """Class decorator that registers a ConfigValidator implementation.

    The class is stored in ``config_validator_registry`` under its
    ``__name__`` and is returned unchanged, so the decorated name still
    refers to the original class.
    """
    config_validator_registry[cls.__name__] = cls
    return cls
30
-
31
-
32
class ConfigValidator(metaclass=abc.ABCMeta):
    """Abstract interface for validating data against a config spec string."""

    @abc.abstractmethod
    def check_pattern_match(self, config_spec: str):
        """Return a truthy match object if this validator handles ``config_spec``.

        A falsy result means the spec is not recognised by this validator.
        """
        pass

    @abc.abstractmethod
    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Check ``actual_data`` against the spec captured in ``pattern_match``.

        ``module_name`` and ``data_type`` are only used to build messages.
        Implementations raise ``ValueError`` when the data does not conform.
        """
        pass
40
-
41
-
42
@register_config_validator
class TensorValidator(ConfigValidator):
    """Validator for the ``tensor`` config spec."""

    def check_pattern_match(self, config_spec: str):
        """Return a match object when ``config_spec`` starts with ``tensor``."""
        pattern = re.compile(r"tensor")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Raise ``ValueError`` unless ``actual_data`` is a torch tensor.

        Returns ``None`` on success; ``pattern_match`` is not used.
        """
        if not torch.is_tensor(actual_data):
            raise ValueError(
                f"Format of {module_name} {data_type} does not match the required format 'tensor' in config.")
52
-
53
-
54
@register_config_validator
class TupleValidator(ConfigValidator):
    """Validator for the ``tuple[x]`` / ``tuple[x]:y`` config spec."""

    def check_pattern_match(self, config_spec: str):
        """Return a match capturing the tuple length and the optional index."""
        pattern = re.compile(r"tuple\[(\d+)\]:?(\d+)?")
        return pattern.match(config_spec)

    def validate(self, actual_data, module_name: str, data_type: str, pattern_match):
        """Validate a tuple against the matched spec and return the selected index.

        The index defaults to 0 when the spec omits ``:y``.

        Raises:
            ValueError: if the index is outside ``[0, length)``, if
                ``actual_data`` is not a tuple, or if its length differs
                from the length given in the spec.
        """
        length, index = pattern_match.groups()
        if index is None:
            index = 0
        length, index = int(length), int(index)

        if not (0 <= index < length):
            # The first fragment ends with a space so the two sentences
            # do not run together in the final message.
            raise ValueError(
                f"Format of {module_name} {data_type} in config.json does not match the required format 'tuple[x]:y'. "
                f"y must be greater than or equal to 0 and less than x.")
        if not isinstance(actual_data, tuple):
            raise ValueError(
                f"Type of {module_name} {data_type} does not match spec of config.json, should be tuple, please check.")
        if len(actual_data) != length:
            raise ValueError(
                f"Length of {module_name} {data_type} does not match spec of config.json, should be {length}, "
                f"actual is {len(actual_data)} please check.")
        return index
78
-
79
-
80
def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
    """Validate ``actual_data`` against ``config_spec`` using the registered validators.

    The first registered validator whose ``check_pattern_match`` accepts the
    spec is used; its ``validate`` result is returned. Validation failures
    (``ValueError``) are logged as warnings rather than raised.

    Returns:
        The value returned by the matching validator, or ``None`` when the
        spec is empty, is not a string, fails validation, or is not
        supported by any registered validator.
    """
    focused_col = None
    if not config_spec or not isinstance(config_spec, str):
        return focused_col
    for validator_cls in config_validator_registry.values():
        config_validator = validator_cls()
        pattern_match = config_validator.check_pattern_match(config_spec)
        if pattern_match:
            try:
                focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match)
            except ValueError as e:
                logger.warning(f"config spec validate failed: {str(e)}")
            # Only the first matching validator is consulted.
            return focused_col
    # Backslashes are escaped explicitly so the emitted text is unchanged
    # while avoiding invalid escape sequences in the string literal.
    logger.warning(f"config spec in {module_name} {data_type} not supported, "
                   f"expected spec:'tuple\\[(\\d+)\\]:(\\d+)' or 'tensor', actual spec: {config_spec}.")
    return focused_col
@@ -1,160 +0,0 @@
1
- # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import argparse
17
- import os
18
- import re
19
- from glob import glob
20
-
21
- import pandas as pd
22
-
23
- from msprobe.pytorch.common.log import logger
24
-
25
-
26
def parse_logfile(logfile):
    """Extract the grad norm from every 'consumed samples' line of a training log.

    Args:
        logfile: path to the training log file.

    Returns:
        A list of floats, one per line containing ``consumed samples``, in
        file order.

    Raises:
        IndexError: if such a line has no ``grad norm: `` field.
        ValueError: if the text after ``grad norm: `` is not a valid float.
    """
    # Raw string: same regex as before without invalid string escapes.
    grad_norm_pattern = re.compile(r'(?<=grad norm: )[\d.]*')
    grad_norm = []
    with open(logfile) as f:
        # Iterate lazily instead of loading the whole log with readlines().
        for line in f:
            if 'consumed samples' in line:
                grad_norm.append(float(grad_norm_pattern.findall(line)[0]))
    return grad_norm
34
-
35
-
36
def parse_monitor_output(output_dir):
    """Load the monitor CSV files of every rank directory matching a path prefix.

    Every path matching ``output_dir + '*'`` is treated as one rank
    directory; the rank is the run of digits following the first ``rank``
    in that path.

    Args:
        output_dir: path prefix of the monitor output directories.

    Returns:
        ``(reduced, unreduced)``: two dicts mapping rank to a list of
        DataFrames, loaded from files whose names contain ``_reduced_`` and
        ``_unreduced_`` respectively. Files are visited in sorted name order.
    """
    reduced = {}
    unreduced = {}
    for directory in glob(output_dir + '*'):
        # Raw string: same regex as before without an invalid string escape.
        rank = int(re.findall(r'(?<=rank)[\d]*', directory)[0])
        unreduced[rank] = []
        reduced[rank] = []
        # os.listdir returns entries in arbitrary order, while callers index
        # these lists by position; sort to make the order deterministic.
        for file in sorted(os.listdir(directory)):
            if '_unreduced_' in file:
                target = unreduced[rank]
            elif '_reduced_' in file:
                target = reduced[rank]
            else:
                # Skip before parsing so an unrelated file cannot break CSV loading.
                logger.info(f'unexpected file {file} in {directory}')
                continue
            target.append(pd.read_csv(os.path.join(directory, file)))
    return reduced, unreduced
53
-
54
-
55
def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
    """Compare per-parameter gradient means before and after reduction.

    For every parameter listed in the first unreduced DataFrame of rank 0,
    the ``mean`` values are summed across ranks for both the reduced and the
    unreduced data; the unreduced sum is divided by ``dp_size`` (and by
    ``tp_size`` for parameters treated as tp duplicates) and compared with
    the reduced sum via ``assert_equal``. Mismatches are collected and
    logged; nothing is returned.

    Args:
        reduced: dict mapping rank to a list of per-step DataFrames.
        unreduced: dict mapping rank to a list of per-step DataFrames.
        tp_size: tensor parallel size.
        dp_size: data parallel size.
        sequence_parallel: whether sequence parallel is enabled.
    """
    # NOTE(review): only the first two steps are checked - confirm intent.
    checked_steps = 2
    world_size = len(reduced)
    errors = []
    for _, row in unreduced[0][0].iterrows():
        param = row['param_name']
        is_tp_duplicate = False
        for step in range(checked_steps):
            # Sum the reduced means over all ranks that hold reduced data.
            reduced_mean = 0.
            for rank in range(world_size):
                if len(reduced[rank]) == 0:
                    continue
                df = reduced[rank][step]
                value = list(df[df['param_name'] == param]['mean'])
                if not value:
                    # A parameter missing from a rank's reduced data at
                    # step 0 marks it as a tp duplicate.
                    if step == 0:
                        is_tp_duplicate = True
                    continue
                reduced_mean += value[0]

            # Sum the unreduced means over all ranks.
            unreduced_mean = 0.
            for rank in range(world_size):
                df = unreduced[rank][step]
                value = list(df[df['param_name'] == param]['mean'])
                if not value:
                    continue
                unreduced_mean += value[0]

            unreduced_mean /= dp_size
            if is_tp_duplicate and (not sequence_parallel or 'embedding' in param):
                unreduced_mean /= tp_size
            try:
                assert_equal(unreduced_mean, reduced_mean)
            except AssertionError as e:
                errors.append([param, step, e, is_tp_duplicate])
    if errors:
        logger.info(errors)
    else:
        logger.info('grad mean is in consist between unreduced grad and reduced grad monitored.')
96
-
97
-
98
def assert_equal(a, b):
    """Raise ``AssertionError`` unless ``a`` and ``b`` agree within 1% relative difference.

    The comparison is skipped (returns ``None``) when either value is zero.

    Raises:
        AssertionError: if ``abs(a / b - 1)`` is not below 0.01; the message
            is ``'<a>, <b>, <relative difference>'``.
    """
    if b == 0 or a == 0:
        return
    rel_diff = abs(a / b - 1)
    # Raised explicitly so the check still runs under ``python -O``, which
    # strips ``assert`` statements.
    if not rel_diff < 0.01:
        raise AssertionError(f'{a}, {b}, {rel_diff}')
108
-
109
-
110
def valid_total_norm(total_norm, reduced, duplicate_embedding):
    """Compare the logged total grad norm with the norm rebuilt from reduced gradients.

    For each step, the squared ``norm`` values of all rows of every rank's
    reduced DataFrame are summed, and the square root is compared with
    ``total_norm[step]`` via ``assert_equal``. Mismatches are collected and
    logged; nothing is returned.

    Args:
        total_norm: list of per-step total grad norms.
        reduced: dict mapping rank to a list of per-step DataFrames.
        duplicate_embedding: when true, rows whose ``param_name`` contains
            ``word_embedding`` are excluded from the sum.
    """
    steps = len(total_norm)
    world_size = len(reduced)
    errors = []
    for step in range(steps):
        calculated_norm = 0.
        for rank in range(world_size):
            if len(reduced[rank]) == 0:
                # Report ranks without reduced data once, on the first step.
                if step == 0:
                    logger.info(f'rank {rank} is duplicated in dp group')
                continue
            for _, row in reduced[rank][step].iterrows():
                if duplicate_embedding and 'word_embedding' in row['param_name']:
                    continue
                calculated_norm += row['norm'] ** 2
        try:
            assert_equal(calculated_norm ** 0.5, total_norm[step])
        except AssertionError as e:
            errors.append([step, e])
    if errors:
        # Format the errors into the message; the original passed them as a
        # stray positional argument with no placeholder.
        logger.info(f'total norm errors: {errors}')
    else:
        logger.info('grad norm in consist between training log and reduced gradients monitored')
133
-
134
-
135
if __name__ == "__main__":
    # Command-line entry point: checks monitor output against a training log.
    parser = argparse.ArgumentParser()
    parser.add_argument('--monitor_output', '-m', type=str, required=True,
                        help='path prefix to the output of monitor e.g. monitor_output/Aug12_07-16')
    parser.add_argument('--logfile', '-l', type=str, required=True, help='path to the training log file')
    parser.add_argument('--tp_size', '-t', type=int, required=True, help='tp parallel size')
    parser.add_argument('--dp_size', '-d', type=int, required=True, help='dp parallel size')
    parser.add_argument('--pp_size', '-p', type=int, required=True, help='pp parallel size')
    parser.add_argument('--untie_embeddings_and_output_weights', '-u', action="store_true", default=False,
                        help='whether untie_embeddings_and_output_weights in pp parallel')
    parser.add_argument('--sequence_parallel', '-s', action="store_true", default=False,
                        help='whether sequence parallel is enabled. Add -s to store true')

    args = parser.parse_args()

    # Validate with parser.error instead of assert: assert statements are
    # stripped under ``python -O``, and each message now names the
    # parallelism that is actually being checked.
    for parallel_kind in ('tp', 'dp', 'pp'):
        if getattr(args, f'{parallel_kind}_size') <= 0:
            parser.error(f'if {parallel_kind} not enabled, set {parallel_kind}_size = 1')

    total_norm = parse_logfile(args.logfile)
    reduced, unreduced = parse_monitor_output(args.monitor_output)

    duplicate_embedding = not args.untie_embeddings_and_output_weights and args.pp_size > 1

    valid_total_norm(total_norm, reduced, duplicate_embedding)
    valid_reduce(reduced, unreduced, args.tp_size, args.dp_size, args.sequence_parallel)