mindstudio-probe 1.1.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +3 -2
  2. mindstudio_probe-1.2.2.dist-info/RECORD +415 -0
  3. msprobe/CMakeLists.txt +5 -0
  4. msprobe/README.md +16 -21
  5. msprobe/config.json +1 -0
  6. msprobe/core/common/const.py +185 -11
  7. msprobe/core/common/exceptions.py +3 -1
  8. msprobe/core/common/file_utils.py +33 -7
  9. msprobe/core/common/inplace_ops.yaml +4 -0
  10. msprobe/core/common/utils.py +42 -14
  11. msprobe/core/common_config.py +6 -0
  12. msprobe/core/compare/acc_compare.py +139 -128
  13. msprobe/core/compare/check.py +31 -29
  14. msprobe/core/compare/compare_cli.py +17 -16
  15. msprobe/core/compare/highlight.py +186 -99
  16. msprobe/core/compare/layer_mapping/data_scope_parser.py +19 -8
  17. msprobe/core/compare/layer_mapping/layer_mapping.py +21 -14
  18. msprobe/core/compare/layer_mapping/postprocess_pass.py +4 -3
  19. msprobe/core/compare/merge_result/merge_result.py +381 -0
  20. msprobe/core/compare/merge_result/merge_result_cli.py +31 -0
  21. msprobe/core/compare/merge_result/utils.py +81 -0
  22. msprobe/core/compare/multiprocessing_compute.py +2 -2
  23. msprobe/core/compare/npy_compare.py +109 -147
  24. msprobe/core/compare/utils.py +199 -69
  25. msprobe/core/data_dump/data_collector.py +100 -25
  26. msprobe/core/data_dump/data_processor/base.py +130 -28
  27. msprobe/core/data_dump/data_processor/factory.py +8 -3
  28. msprobe/core/data_dump/data_processor/mindspore_processor.py +170 -23
  29. msprobe/core/data_dump/data_processor/pytorch_processor.py +175 -64
  30. msprobe/core/data_dump/json_writer.py +54 -8
  31. msprobe/core/data_dump/scope.py +19 -18
  32. msprobe/core/overflow_check/abnormal_scene.py +9 -5
  33. msprobe/core/overflow_check/checker.py +1 -1
  34. msprobe/core/overflow_check/utils.py +1 -1
  35. msprobe/docs/01.installation.md +121 -17
  36. msprobe/docs/02.config_introduction.md +18 -16
  37. msprobe/docs/03.config_examples.md +24 -0
  38. msprobe/docs/05.data_dump_PyTorch.md +107 -58
  39. msprobe/docs/06.data_dump_MindSpore.md +95 -34
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +18 -18
  41. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -6
  42. msprobe/docs/10.accuracy_compare_PyTorch.md +99 -41
  43. msprobe/docs/11.accuracy_compare_MindSpore.md +249 -48
  44. msprobe/docs/12.overflow_check_PyTorch.md +1 -1
  45. msprobe/docs/19.monitor.md +310 -220
  46. msprobe/docs/21.visualization_PyTorch.md +125 -35
  47. msprobe/docs/22.visualization_MindSpore.md +149 -41
  48. msprobe/docs/23.generate_operator_PyTorch.md +107 -0
  49. msprobe/docs/24.code_mapping_Mindspore.md +28 -0
  50. msprobe/docs/{23.tool_function_introduction.md → 25.tool_function_introduction.md} +1 -0
  51. msprobe/docs/26.data_dump_PyTorch_baseline.md +37 -0
  52. msprobe/docs/27.dump_json_instruction.md +525 -0
  53. msprobe/docs/28.debugger_save_instruction.md +94 -0
  54. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  55. msprobe/docs/FAQ.md +26 -2
  56. msprobe/docs/accuracy_checker_MindSpore/accuracy_checker_MindSpore_baseline.md +14 -0
  57. msprobe/docs/data_dump_MindSpore/data_dump_MindSpore_baseline.md +22 -0
  58. msprobe/docs/img/merge_result.png +0 -0
  59. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  60. msprobe/docs/img/visualization/fuzzy_match_ms.png +0 -0
  61. msprobe/docs/img/visualization/fuzzy_match_pt.png +0 -0
  62. msprobe/docs/img/visualization/tensorboard_1.png +0 -0
  63. msprobe/docs/img/visualization/tensorboard_2.png +0 -0
  64. msprobe/docs/img/visualization/vis_browser_1.png +0 -0
  65. msprobe/docs/img/visualization/vis_browser_2.png +0 -0
  66. msprobe/docs/img/visualization/vis_precision_info.png +0 -0
  67. msprobe/docs/img/visualization/vis_search_info.png +0 -0
  68. msprobe/docs/img/visualization/vis_show_info.png +0 -0
  69. msprobe/docs/img/visualization/vis_showcase.png +0 -0
  70. msprobe/docs/img/visualization/vis_unmatch_info.png +0 -0
  71. msprobe/docs/visualization/GPTModel.png +0 -0
  72. msprobe/docs/visualization/ParallelMLP.png +0 -0
  73. msprobe/docs/visualization/layer_mapping_example.md +132 -0
  74. msprobe/docs/visualization/mapping.png +0 -0
  75. msprobe/docs/visualization/mapping1.png +0 -0
  76. msprobe/docs/visualization/module_name.png +0 -0
  77. msprobe/docs/visualization/module_name1.png +0 -0
  78. msprobe/docs/visualization/no_mapping.png +0 -0
  79. msprobe/docs/visualization/no_mapping1.png +0 -0
  80. msprobe/docs/visualization/no_mapping_analyze.png +0 -0
  81. msprobe/docs/visualization/top_layer.png +0 -0
  82. msprobe/mindspore/__init__.py +11 -0
  83. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +80 -28
  84. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  85. msprobe/mindspore/api_accuracy_checker/cmd_parser.py +2 -1
  86. msprobe/mindspore/api_accuracy_checker/compute_element.py +52 -8
  87. msprobe/mindspore/api_accuracy_checker/data_manager.py +37 -0
  88. msprobe/mindspore/api_accuracy_checker/main.py +1 -0
  89. msprobe/mindspore/api_accuracy_checker/multi_api_accuracy_checker.py +12 -6
  90. msprobe/mindspore/api_accuracy_checker/multi_data_manager.py +3 -1
  91. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  92. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  93. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  94. msprobe/mindspore/code_mapping/bind.py +264 -0
  95. msprobe/mindspore/code_mapping/cmd_parser.py +40 -0
  96. msprobe/mindspore/code_mapping/graph.py +49 -0
  97. msprobe/mindspore/code_mapping/graph_parser.py +226 -0
  98. msprobe/mindspore/code_mapping/main.py +24 -0
  99. msprobe/mindspore/code_mapping/processor.py +34 -0
  100. msprobe/mindspore/common/const.py +3 -1
  101. msprobe/mindspore/common/utils.py +68 -5
  102. msprobe/mindspore/compare/distributed_compare.py +0 -2
  103. msprobe/mindspore/compare/ms_compare.py +105 -63
  104. msprobe/mindspore/compare/ms_graph_compare.py +14 -5
  105. msprobe/mindspore/debugger/debugger_config.py +28 -2
  106. msprobe/mindspore/debugger/precision_debugger.py +100 -12
  107. msprobe/mindspore/dump/hook_cell/api_registry.py +85 -16
  108. msprobe/mindspore/dump/hook_cell/hook_cell.py +60 -38
  109. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +33 -15
  110. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +11 -1
  111. msprobe/mindspore/dump/hook_cell/wrap_api.py +92 -1
  112. msprobe/mindspore/dump/jit_dump.py +7 -6
  113. msprobe/mindspore/dump/kernel_dump/kernel_config.py +33 -0
  114. msprobe/mindspore/dump/kernel_graph_dump.py +7 -0
  115. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +13 -4
  116. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +2 -2
  117. msprobe/mindspore/grad_probe/grad_analyzer.py +24 -12
  118. msprobe/mindspore/grad_probe/hook.py +13 -4
  119. msprobe/mindspore/mindtorch/__init__.py +18 -0
  120. msprobe/mindspore/mindtorch/mindtorch_adaptor.py +255 -0
  121. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  122. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  123. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  124. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  125. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  126. msprobe/mindspore/monitor/features.py +63 -0
  127. msprobe/mindspore/monitor/module_hook.py +821 -0
  128. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  129. msprobe/mindspore/monitor/utils.py +267 -0
  130. msprobe/mindspore/ms_config.py +13 -3
  131. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +7 -0
  132. msprobe/mindspore/service.py +347 -107
  133. msprobe/msprobe.py +24 -3
  134. msprobe/pytorch/__init__.py +7 -7
  135. msprobe/pytorch/api_accuracy_checker/common/utils.py +31 -16
  136. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +41 -8
  137. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +100 -267
  138. msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +4 -1
  139. msprobe/pytorch/api_accuracy_checker/compare/compare.py +69 -68
  140. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +54 -0
  141. msprobe/pytorch/api_accuracy_checker/compare/compare_input.py +51 -0
  142. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +2 -4
  143. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +55 -31
  144. msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +106 -0
  145. msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py +107 -0
  146. msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +151 -0
  147. msprobe/pytorch/api_accuracy_checker/precision_standard/benchmark_compare.py +226 -0
  148. msprobe/pytorch/api_accuracy_checker/precision_standard/binary_consistency.py +68 -0
  149. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +218 -0
  150. msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +104 -0
  151. msprobe/pytorch/api_accuracy_checker/precision_standard/thousandth_standard.py +63 -0
  152. msprobe/pytorch/api_accuracy_checker/precision_standard/ulp_compare.py +200 -0
  153. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +57 -1
  154. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +2 -1
  155. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +42 -14
  156. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +64 -19
  157. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +34 -4
  158. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +5 -3
  159. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  160. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  161. msprobe/pytorch/bench_functions/mish.py +21 -0
  162. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  163. msprobe/pytorch/bench_functions/npu_fusion_attention.py +42 -10
  164. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  165. msprobe/pytorch/common/parse_json.py +2 -1
  166. msprobe/pytorch/common/utils.py +116 -2
  167. msprobe/pytorch/compare/distributed_compare.py +17 -29
  168. msprobe/pytorch/compare/pt_compare.py +40 -20
  169. msprobe/pytorch/debugger/debugger_config.py +42 -17
  170. msprobe/pytorch/debugger/precision_debugger.py +56 -12
  171. msprobe/pytorch/dump/module_dump/__init__.py +0 -0
  172. msprobe/pytorch/dump/module_dump/module_dump.py +86 -0
  173. msprobe/pytorch/dump/module_dump/module_processer.py +204 -0
  174. msprobe/pytorch/free_benchmark/common/params.py +2 -1
  175. msprobe/pytorch/free_benchmark/common/utils.py +3 -0
  176. msprobe/pytorch/free_benchmark/compare/grad_saver.py +0 -2
  177. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +31 -47
  178. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +0 -4
  179. msprobe/pytorch/function_factory.py +7 -1
  180. msprobe/pytorch/hook_module/__init__.py +1 -1
  181. msprobe/pytorch/hook_module/hook_module.py +14 -11
  182. msprobe/pytorch/hook_module/register_optimizer_hook.py +59 -0
  183. msprobe/pytorch/hook_module/support_wrap_ops.yaml +36 -1
  184. msprobe/pytorch/hook_module/wrap_distributed.py +10 -8
  185. msprobe/pytorch/hook_module/wrap_functional.py +0 -40
  186. msprobe/pytorch/monitor/anomaly_analyse.py +1 -1
  187. msprobe/pytorch/monitor/anomaly_detect.py +98 -28
  188. msprobe/pytorch/monitor/csv2tb.py +164 -0
  189. msprobe/pytorch/monitor/distributed/wrap_distributed.py +25 -14
  190. msprobe/pytorch/monitor/features.py +3 -3
  191. msprobe/pytorch/monitor/module_hook.py +543 -318
  192. msprobe/pytorch/monitor/module_metric.py +27 -48
  193. msprobe/pytorch/monitor/module_spec_verifier.py +3 -1
  194. msprobe/pytorch/monitor/optimizer_collect.py +76 -56
  195. msprobe/pytorch/monitor/unittest/test_monitor.py +24 -9
  196. msprobe/pytorch/monitor/utils.py +84 -48
  197. msprobe/pytorch/online_dispatch/dispatch.py +8 -2
  198. msprobe/pytorch/parse_tool/lib/compare.py +10 -10
  199. msprobe/pytorch/parse_tool/lib/config.py +5 -7
  200. msprobe/pytorch/parse_tool/lib/file_desc.py +15 -1
  201. msprobe/pytorch/parse_tool/lib/interactive_cli.py +10 -10
  202. msprobe/pytorch/parse_tool/lib/parse_exception.py +7 -7
  203. msprobe/pytorch/parse_tool/lib/parse_tool.py +11 -10
  204. msprobe/pytorch/parse_tool/lib/utils.py +18 -19
  205. msprobe/pytorch/parse_tool/lib/visualization.py +9 -10
  206. msprobe/pytorch/pt_config.py +19 -22
  207. msprobe/pytorch/service.py +264 -115
  208. msprobe/visualization/builder/graph_builder.py +93 -10
  209. msprobe/visualization/builder/msprobe_adapter.py +30 -6
  210. msprobe/visualization/compare/graph_comparator.py +64 -14
  211. msprobe/visualization/compare/mode_adapter.py +1 -15
  212. msprobe/visualization/graph/base_node.py +15 -19
  213. msprobe/visualization/graph/distributed_analyzer.py +395 -0
  214. msprobe/visualization/graph/graph.py +9 -0
  215. msprobe/visualization/graph/node_op.py +4 -2
  216. msprobe/visualization/graph_service.py +100 -27
  217. msprobe/visualization/utils.py +24 -31
  218. mindstudio_probe-1.1.1.dist-info/RECORD +0 -341
  219. msprobe/pytorch/functional/module_dump.py +0 -84
  220. msprobe/pytorch/module_processer.py +0 -150
  221. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  222. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  223. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  224. {mindstudio_probe-1.1.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
  225. /msprobe/docs/{data_dump_Mindspore → data_dump_MindSpore}/dynamic_graph_quick_start_example.md +0 -0
  226. /msprobe/{pytorch/functional → mindspore/code_mapping}/__init__.py +0 -0
msprobe/pytorch/monitor/module_metric.py

@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,16 +12,12 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- import itertools
- import math
  import re
- import statistics

  import torch

- from msprobe.core.common.const import MonitorConst
- from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean
- from msprobe.core.common.log import logger
+ from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
+ from msprobe.pytorch.monitor.utils import get_nan_tensor


  def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
@@ -31,7 +27,9 @@ def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank):
  return f"{module_or_param_name}/rank{rank}/{tag}"


- def squash_param_name(param_name):
+ def squash_param_name(param_name, enable=True):
+ if not enable:
+ return param_name
  name = ''
  for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']:
  match = re.findall(pattern, param_name)
@@ -63,7 +61,7 @@ class TensorMetrics:
  self.metrics = {} # tensor_tag --> []
  self.cur_idx = {}

- def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8):
+ def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank):
  """get stats and insert into metrics dictionary"""
  prefix = get_summary_writer_tag_name(module_name, tensor_name, rank)
  for stat_op in stat_ops:
@@ -120,14 +118,14 @@ class NormMetric(Metric):
  @staticmethod
  def get_metric_value(tensor, eps):
  return get_norm(tensor)
-
+

  @register_config_metric("zeros")
  class ZerosMetric(Metric):
  @staticmethod
  def get_metric_value(tensor, eps):
  return get_zeros(tensor, eps)
-
+

  @register_config_metric("nans")
  class NaNsMetric(Metric):
@@ -146,48 +144,29 @@ class IdentMetric(Metric):


  def get_metrics(ops, tag2tensor, eps, out_dict=None):
+ """
+ :param ops: ["op1", "op2"]
+ :param tag2tensor: {
+ '0:fc.input:0/actv': torch.randn([3, 4]),
+ '0:fc.output:0/actv': torch.randn([3, 3])
+ }
+ :param eps: float 1e-8
+ :param out_dict:{
+ '0:fc.input:0/actv': {"op1": op1(torch.randn([3, 4])), "op2": op2(torch.randn([3, 4]))}
+ '0:fc.output:0/actv': {"op1": op1(torch.randn([3, 3])), "op2": op2(torch.randn([3, 3]))}
+ }
+ :return: out_dict
+ """
  if out_dict is None:
  out_dict = {}
  for tag, tensor in tag2tensor.items():
  if tag not in out_dict:
  out_dict[tag] = {}
- for metric_name in ops:
+ if not torch.is_tensor(tensor):
+ # Non-tensor in/output filled with nan.
+ out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
+ continue
+ for metric_name in ops:
  fun_metric = config_metric_registry.get(metric_name)
  out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
  return out_dict
-
-
- def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
- if not metric_value:
- return
- tensors = []
- tags = list(itertools.product(metric_value.keys(), ops))
- for op2tensor in metric_value.values():
- tensors.extend(op2tensor.values())
- with torch.no_grad():
- metric_list = torch.stack(tensors).cpu()
- for tag, metric in zip(tags, metric_list):
- summary_writer.add_scalar(tag, metric, step)
-
-
- def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
- write_metrics_base(ops, summary_writer, metric_value, step, prefix='')
-
- if not summary_writer.header:
- # Forward norms use input.ops_ and output.ops_; backward norms use input_grad.ops_ and output_grad.ops_
- if prefix in {"actv", "actv_grad"}:
- if prefix == "actv":
- input_and_output = [MonitorConst.ACTV_IN, MonitorConst.ACTV_OUT]
- else:
- input_and_output = [MonitorConst.ACTVGRAD_IN, MonitorConst.ACTVGRAD_OUT]
- ops_ = [MonitorConst.DOT.join(i[::-1]) for i in itertools.product(ops, input_and_output)]
- summary_writer.header = ["module_name", "step", *ops_]
- else:
- summary_writer.header = ["param_name", "step", *ops]
-
- for key in metric_value.keys():
- if MonitorConst.VPP_SEP in key:
- summary_writer.header.insert(0, 'vpp_stage')
- break
- summary_writer.write_csv(prefix, step)
- summary_writer.header = []
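
The reworked get_metrics above now fills metrics for non-tensor inputs/outputs with a device-local NaN tensor instead of failing. Below is a minimal usage sketch, not part of the released diff: the import path is taken from the file list above, and the op names are the "zeros"/"nans" metrics registered in this file.

```python
# Illustrative sketch only (assumes mindstudio-probe 1.2.2 is installed).
import torch
from msprobe.pytorch.monitor.module_metric import get_metrics  # path per the file list above

ops = ["zeros", "nans"]  # registered via @register_config_metric in this file
tag2tensor = {
    "0:fc.input:0/actv": torch.randn(3, 4),
    "0:fc.output:0/actv": None,  # non-tensor value: its metrics are filled with NaN
}
out = get_metrics(ops, tag2tensor, eps=1e-8)
print(out["0:fc.output:0/actv"]["nans"])  # tensor(nan) rather than an exception
```
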
msprobe/pytorch/monitor/module_spec_verifier.py

@@ -17,7 +17,7 @@ import re
  import abc
  import torch

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger

  # Registry that stores all validator implementation classes
  config_validator_registry = {}
@@ -79,6 +79,8 @@ class TupleValidator(ConfigValidator):

  def validate_config_spec(config_spec: str, actual_data, module_name: str, data_type: str):
  focused_col = None
+ if not config_spec or not isinstance(config_spec, str):
+ return focused_col
  for _, validator_cls in config_validator_registry.items():
  config_validator = validator_cls()
  pattern_match = config_validator.check_pattern_match(config_spec)
msprobe/pytorch/monitor/optimizer_collect.py

@@ -1,4 +1,4 @@
- # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
  # All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,27 +13,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from abc import ABC, abstractmethod
  from collections import defaultdict

  import torch
  import torch.distributed as dist

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger
  from msprobe.pytorch.monitor.utils import MVResult, MVGradResult


  class OptimizerMon(object):
- wrapped_optimizer = None
-
  def __init__(self) -> None:
  self.fp16_to_fp32_param = {}
  self.is_stage3 = False

- @classmethod
- def set_wrapped_optimizer(cls, wrapped_optimizer):
- cls.wrapped_optimizer = wrapped_optimizer
-
  def fetch_mv(self, monitor, torch_opt, params2name):
  pass

@@ -83,11 +76,10 @@ class OptimizerMon(object):
  ratio_dict = defaultdict()
  param2name = defaultdict()
  fp32_partitioned_groups_flat_grad = defaultdict()
- mix_prec_opt = OptimizerMon.wrapped_optimizer
  partition_id = dist.get_rank()

  def get_flatten_grad(self, optimizer, group_idx):
- if fp32_partitioned_groups_flat[group_idx].grad is None:
+ if fp32_partitioned_groups_flat[group_idx].grad is None:
  if partition_id == dist.get_world_size() - 1 and not self.is_stage3:
  fp32_partitioned_groups_flat_grad = optimizer.flatten_dense_tensors_aligned(
  optimizer.averaged_gradients[group_idx],
@@ -102,7 +94,7 @@ class OptimizerMon(object):
  return fp32_partitioned_groups_flat[group_idx].grad

  for group_idx in range(len(fp32_partitioned_groups_flat)):
- fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, mix_prec_opt, group_idx)
+ fp32_partitioned_groups_flat_grad[group_idx] = get_flatten_grad(self, torch_opt, group_idx)

  for name in params2name.values():
  start_idx, end_idx, group_idx, group_with_rank = name2indices[name]
@@ -111,9 +103,9 @@ class OptimizerMon(object):
  fp32_param = fp32_partitioned_groups_flat[group_idx][start_idx: end_idx]
  fp32_param.grad = fp32_partitioned_groups_flat_grad[group_idx][start_idx: end_idx]
  param2name[fp32_param] = name
- if not mix_prec_opt.state:
+ if not torch_opt.state:
  continue
- state_param = list(mix_prec_opt.state.values())[group_idx]
+ state_param = list(torch_opt.state.values())[group_idx]
  exp_avg = state_param.get("exp_avg", None)
  exp_avg_sq = state_param.get("exp_avg_sq", None)
  if exp_avg is None or exp_avg_sq is None:
@@ -151,29 +143,33 @@ class MixPrecisionOptimizerMon(OptimizerMon):
  Mixed-precision training speeds up training and reduces memory consumption by appropriately lowering the precision of some computations.
  """

+ def map_fp16_tp_fp32_param(self, torch_opt):
+ for fp16_group, fp32_group in zip(torch_opt.float16_groups, torch_opt.fp32_from_float16_groups):
+ for fp16_param, fp32_param in zip(fp16_group, fp32_group):
+ self.fp16_to_fp32_param[fp16_param] = fp32_param
+
  def fetch_mv(self, monitor, torch_opt, params2name):
- mix_prec_opt = self.wrapped_optimizer
+ if not self.fp16_to_fp32_param and torch_opt is not None:
+ self.map_fp16_tp_fp32_param(torch_opt)

- if not self.fp16_to_fp32_param and mix_prec_opt is not None:
- for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups):
- for fp16_param, fp32_param in zip(fp16_group, fp32_group):
- self.fp16_to_fp32_param[fp16_param] = fp32_param
  return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


  class MegatronDistributedOptimizerMon(OptimizerMon):
- def fetch_mv(self, monitor, torch_opt, params2name):
- mix_prec_opt = self.wrapped_optimizer
- if not (hasattr(mix_prec_opt, "model_float16_groups") and
- hasattr(mix_prec_opt, "shard_fp32_from_float16_groups")):
+ def map_fp16_tp_fp32_param(self, torch_opt):
+ if not (hasattr(torch_opt, "model_float16_groups") and
+ hasattr(torch_opt, "shard_fp32_from_float16_groups")):
  raise Exception(
  "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, "
  "if not, please check megatron-lm version")
- if not self.fp16_to_fp32_param and mix_prec_opt is not None:
- for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups,
- mix_prec_opt.shard_fp32_from_float16_groups):
- for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
- self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+ for fp16_group, shard_fp32_group in zip(torch_opt.model_float16_groups,
+ torch_opt.shard_fp32_from_float16_groups):
+ for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group):
+ self.fp16_to_fp32_param[fp16_param] = shard_fp32_param
+
+ def fetch_mv(self, monitor, torch_opt, params2name):
+ if not self.fp16_to_fp32_param and torch_opt is not None:
+ self.map_fp16_tp_fp32_param(torch_opt)

  return self._fetch_mv_in_adam(monitor, torch_opt, params2name)

@@ -183,15 +179,40 @@ class MegatronFP32OptimizerMon(OptimizerMon):
  return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


+ class MegatronChainedDistributedOptimizerMon(MegatronDistributedOptimizerMon):
+ def fetch_mv(self, monitor, torch_opt, params2name):
+ if not self.fp16_to_fp32_param and torch_opt is not None:
+ for opt in torch_opt.chained_optimizers:
+ self.map_fp16_tp_fp32_param(opt)
+
+ if not isinstance(torch_opt, torch.optim.Optimizer):
+ torch_opt.state = {}
+ for opt in torch_opt.chained_optimizers:
+ torch_opt.state.update(opt.optimizer.state)
+ return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
+ class MegatronChainedMixPrecisionOptimizerMon(MixPrecisionOptimizerMon):
+ def fetch_mv(self, monitor, torch_opt, params2name):
+ if not self.fp16_to_fp32_param and torch_opt is not None:
+ for opt in torch_opt.chained_optimizers:
+ self.map_fp16_tp_fp32_param(opt)
+
+ if not isinstance(torch_opt, torch.optim.Optimizer):
+ torch_opt.state = {}
+ for opt in torch_opt.chained_optimizers:
+ torch_opt.state.update(opt.optimizer.state)
+ return self._fetch_mv_in_adam(monitor, torch_opt, params2name)
+
+
  class DeepSpeedZeroOptimizerStage0Mon(OptimizerMon):
  def fetch_mv(self, monitor, torch_opt, params2name):
  return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


  class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):
- def get_param_index(self, params2name, name2index):
- mix_prec_opt = OptimizerMon.wrapped_optimizer
- fp16_groups = mix_prec_opt.fp16_partitioned_groups
+ def get_param_index(self, params2name, name2index, torch_opt):
+ fp16_groups = torch_opt.fp16_partitioned_groups
  name2indices = defaultdict()
  index_length = defaultdict()
  index = 0
@@ -210,13 +231,11 @@ class DeepSpeedZeroOptimizerStage3Mon(OptimizerMon):

  def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
  self.is_stage3 = True
- mix_prec_opt = OptimizerMon.wrapped_optimizer
- fp32_partitioned_groups_flat = mix_prec_opt.fp32_partitioned_groups_flat
+ fp32_partitioned_groups_flat = torch_opt.fp32_partitioned_groups_flat
  return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)


  class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
-
  @staticmethod
  def get_group_index(fp32_length, world_size, index):
  for i in range(len(fp32_length) - 1):
@@ -229,12 +248,11 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
  return sub_interval_start, min(sub_index, world_size - 1)
  return fp32_length[-1], 0

- def get_param_index(self, params2name, name2index):
- mix_prec_opt = OptimizerMon.wrapped_optimizer
- padding = mix_prec_opt.groups_padding
+ def get_param_index(self, params2name, name2index, torch_opt):
+ padding = torch_opt.groups_padding
  world_size = dist.get_world_size()
  fp32_length = [0]
- for fp32_group_index, single_partition_of_fp32_group in enumerate(mix_prec_opt.single_partition_of_fp32_groups):
+ for fp32_group_index, single_partition_of_fp32_group in enumerate(torch_opt.single_partition_of_fp32_groups):
  fp32_length.append(len(single_partition_of_fp32_group) * world_size + fp32_length[fp32_group_index])

  bf16_groups = []
@@ -242,7 +260,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
  index_length = defaultdict()
  index = 0
  idx = 0
- for group_idx, bf16_group in enumerate(mix_prec_opt.bit16_groups):
+ for group_idx, bf16_group in enumerate(torch_opt.bit16_groups):
  bf16_groups.extend(bf16_group)
  for param in bf16_group:
  param_length = len(param.flatten())
@@ -250,7 +268,7 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
  index_length[idx] = (index, index + param_length, group_idx, group_index, group_with_rank)
  index += param_length
  idx += 1
- group_length = len(bf16_groups) / len(mix_prec_opt.bit16_groups)
+ group_length = len(bf16_groups) / len(torch_opt.bit16_groups)
  for _, name in params2name.items():
  name_index = name2index[name]
  start_idx, end_idx, group_idx, group_index, group_with_rank = index_length[name_index]
@@ -264,32 +282,34 @@ class DeepSpeedZeroOptimizerStage1or2Mon(OptimizerMon):
  return name2indices

  def fetch_mv(self, monitor, torch_opt, params2name, name2indices=None):
- mix_prec_opt = OptimizerMon.wrapped_optimizer
- fp32_partitioned_groups_flat = mix_prec_opt.single_partition_of_fp32_groups
+ fp32_partitioned_groups_flat = torch_opt.single_partition_of_fp32_groups
  return self._fetch_mv_grad_in_adam(monitor, torch_opt, params2name, name2indices, fp32_partitioned_groups_flat)


  class DummyOptimizerMon(OptimizerMon):
  def fetch_mv(self, monitor, torch_opt, params2name):
- return MVResult(exp_avg=None, exp_avg_sq=None, update=None, ratio=None)
+ return self._fetch_mv_in_adam(monitor, torch_opt, params2name)


  class OptimizerMonFactory:
  _optimizer_mon_map = {
- "Megatron_Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
- "Megatron_DistributedOptimizer": MegatronDistributedOptimizerMon,
- "Megatron_FP32Optimizer": MegatronFP32OptimizerMon,
- "DeepSpeedZeroOptimizer_Stage0": DeepSpeedZeroOptimizerStage0Mon,
- "DeepSpeedZeroOptimizer_Stage1_or_2": DeepSpeedZeroOptimizerStage1or2Mon,
+ "FP32Optimizer": MegatronFP32OptimizerMon,
+ "Float16OptimizerWithFloat16Params": MixPrecisionOptimizerMon,
+ "DistributedOptimizer": MegatronDistributedOptimizerMon,
+ "ChainedDistributedOptimizer": MegatronChainedDistributedOptimizerMon,
+ "ChainedFloat16OptimizerWithFloat16Params": MegatronChainedMixPrecisionOptimizerMon,
+ "BF16_Optimizer": DeepSpeedZeroOptimizerStage0Mon,
+ "DeepSpeedZeroOptimizer": DeepSpeedZeroOptimizerStage1or2Mon,
  "DeepSpeedZeroOptimizer_Stage3": DeepSpeedZeroOptimizerStage3Mon,
- "unknown": DummyOptimizerMon
+ "Adam": DummyOptimizerMon
  }

  @staticmethod
- def create_optimizer_mon(opt_ty: str):
- if not opt_ty:
- return DummyOptimizerMon()
- optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(opt_ty)
- if not optimizer_mon_class:
- raise Exception("opt_ty should be one of: " + ", ".join(OptimizerMonFactory._optimizer_mon_map.keys()))
- return optimizer_mon_class()
+ def create_optimizer_mon(optimizer):
+ # auto replace opt_ty
+ optimizer_class = optimizer.__class__.__name__
+ if optimizer_class == "ChainedOptimizer":
+ optimizer_class = "Chained" + optimizer.chained_optimizers[0].__class__.__name__
+
+ optimizer_mon_class = OptimizerMonFactory._optimizer_mon_map.get(optimizer_class, DummyOptimizerMon)
+ return optimizer_mon_class(), optimizer_class
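
With the change above, OptimizerMonFactory.create_optimizer_mon takes the optimizer instance itself, keys the lookup off its class name (unwrapping ChainedOptimizer), and falls back to DummyOptimizerMon. A small sketch follows, assuming the import path from the file list above; it is illustrative rather than part of the diff.

```python
# Illustrative sketch only (assumes mindstudio-probe 1.2.2 is installed).
import torch
from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory, DummyOptimizerMon

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

# The factory now resolves the monitor from the optimizer's class name ("Adam" here)
# instead of a caller-supplied opt_ty string, and returns the resolved name as well.
opt_mon, opt_class = OptimizerMonFactory.create_optimizer_mon(optimizer)
assert opt_class == "Adam" and isinstance(opt_mon, DummyOptimizerMon)
```
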
msprobe/pytorch/monitor/unittest/test_monitor.py

@@ -1,11 +1,26 @@
+ # Copyright (c) 2024-2025, Huawei Technologies Co., Ltd.
+ # All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import argparse
  import os
  import re
- import argparse
  from glob import glob

  import pandas as pd

- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger


  def parse_logfile(logfile):
@@ -21,19 +36,19 @@ def parse_logfile(logfile):
  def parse_monitor_output(output_dir):
  reduced = {}
  unreduced = {}
- for dir in glob(output_dir + '*'):
- rank = int(re.findall('(?<=rank)[\d]*', dir)[0])
+ for directory in glob(output_dir + '*'):
+ rank = int(re.findall('(?<=rank)[\d]*', directory)[0])
  unreduced[rank] = []
  reduced[rank] = []
- for file in os.listdir(dir):
- df = pd.read_csv(os.path.join(dir, file))
+ for file in os.listdir(directory):
+ df = pd.read_csv(os.path.join(directory, file))
  if '_unreduced_' in file:
  unreduced[rank].append(df)
  pass
  elif '_reduced_' in file:
  reduced[rank].append(df)
  else:
- logger.info(f'unexpected file {file} in {dir}')
+ logger.info(f'unexpected file {file} in {directory}')
  return reduced, unreduced


@@ -41,7 +56,7 @@ def valid_reduce(reduced, unreduced, tp_size, dp_size, sequence_parallel):
  steps = len(reduced[0])
  world_size = len(reduced)
  errors = []
- for index, row in unreduced[0][0].iterrows():
+ for _, row in unreduced[0][0].iterrows():
  param = row['param_name']
  is_tp_duplicate = False
  for step in range(2):
@@ -103,7 +118,7 @@ def valid_total_norm(total_norm, reduced, duplicate_embedding):
  if step == 0:
  logger.info(f'rank {rank} is duplicated in dp group')
  continue
- for index, row in reduced[rank][step].iterrows():
+ for _, row in reduced[rank][step].iterrows():
  if duplicate_embedding and 'word_embedding' in row['param_name']:
  continue
  calculated_norm += row['norm'] ** 2
msprobe/pytorch/monitor/utils.py

@@ -16,13 +16,27 @@ import inspect
  from collections import namedtuple
  from datetime import timezone, timedelta
  from functools import wraps
+ from datetime import datetime
+ import os
+ import re

  import torch

  from msprobe.core.common.const import MonitorConst, Const
- from msprobe.core.common.log import logger
+ from msprobe.pytorch.common.log import logger
  from msprobe.core.common.utils import is_int
+ from msprobe.core.common.file_utils import check_file_or_directory_path

+
+ device = "cpu"
+ try:
+ import torch_npu
+ device = "npu"
+ except ImportError:
+ if torch.cuda.is_available():
+ device = "cuda"
+
+ NAN_TENSOR_ON_DEVICE = None
  FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024
  FILE_NAME_MAX_LENGTH = 255
  DIRECTORY_MAX_LENGTH = 4096
@@ -39,6 +53,17 @@ class MsgConst:
  SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"]


+ def get_output_base_dir():
+ return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR)
+
+
+ def get_nan_tensor():
+ global NAN_TENSOR_ON_DEVICE
+ if not NAN_TENSOR_ON_DEVICE:
+ NAN_TENSOR_ON_DEVICE = torch.tensor(torch.nan, device=device)
+ return NAN_TENSOR_ON_DEVICE
+
+
  def filter_special_chars(func):
  @wraps(func)
  def func_level(msg):
@@ -64,60 +89,19 @@ def get_param_struct(param):
  return res


- def is_recomputation():
- """Check if the current operation is in the re-computation phase.
-
- This function inspects the current call stack to indicate whether the current operation is in the
- re-computation phase. We use a blacklist mechanism, now supported megatron and mindspeed framework.
- megatron: The 'backward' function is called by the 'torch/autograd/function.py' file.
- mindspeed: The 'checkpoint_function_backward' function is called by the 'torch/autograd/function.py'
- file or the custom module(use CheckpointWithoutOutput) with the 'backward' function is executed within the
- 'torch/_tensor.py' file.
-
- Returns:
- bool: True if in the re-computation phase, False otherwise.
- """
- backward_function_indices = []
- call_stack = inspect.stack()
-
- # Identify the function 'backward' is being executed within the 'torch/_tensor.py' file.
- for frame_info in call_stack:
- if frame_info.function == Const.BACKWARD and frame_info.filename.endswith('torch/_tensor.py'):
- del call_stack
- return True
-
- # Identify indices in the call stack where the specific function is being executed
- for idx, frame_info in enumerate(call_stack):
- if frame_info.function == Const.BACKWARD or frame_info.function == 'checkpoint_function_backward':
- backward_function_indices.append(idx)
-
- # Check if the execution is within 'torch/autograd/function.py' file
- for idx in backward_function_indices:
- # The Megatron and MindSpeed L0&L1 scenes
- if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
- del call_stack
- return True
- # The latest MindSpeed L2 and ModelLink scenes
- if idx + 2 < len(call_stack) and call_stack[idx + 2].filename.endswith('torch/autograd/function.py'):
- del call_stack
- return True
-
- del call_stack
- return False
-
-
  def validate_ops(ops):
  if not isinstance(ops, list):
  raise TypeError("ops should be a list")
- if not ops:
- raise TypeError(f"specify ops to calculate metrics. Optional ops: {MonitorConst.OP_LIST}")
-
  valid_ops = []
  for op in ops:
  if op not in MonitorConst.OP_LIST:
  logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}")
- else:
- valid_ops.append(op)
+ continue
+ valid_ops.append(op)
+ if not valid_ops:
+ default_op = MonitorConst.OP_LIST[0]
+ valid_ops.append(default_op)
+ logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used")
  return valid_ops


@@ -164,6 +148,11 @@ def validate_mg_distribution(mg_distribution):
  raise TypeError('mg_distribution should be a bool')


+ def validate_param_distribution(param_distribution):
+ if not isinstance(param_distribution, bool):
+ raise TypeError('param_distribution should be a bool')
+
+
  def validate_cc_distribution(cc_distribution):
  if not isinstance(cc_distribution, dict):
  raise TypeError('cc_distribution should be a dictionary')
@@ -184,6 +173,11 @@ def validate_cc_distribution(cc_distribution):
  raise TypeError(f'{key} of cc_distribution is not supported.')


+ def validate_squash_name(squash_name):
+ if not isinstance(squash_name, bool):
+ raise TypeError('squash_name should be a bool')
+
+
  def validate_alert(alert):
  if not isinstance(alert, dict):
  raise TypeError('alert should be a dictionary')
@@ -240,6 +234,9 @@ def validate_config(config):
  mg_distribution = config.get('mg_distribution', False)
  validate_mg_distribution(mg_distribution)

+ param_distribution = config.get('param_distribution', False)
+ validate_param_distribution(param_distribution)
+
  cc_distribution = config.get('cc_distribution', {})
  validate_cc_distribution(cc_distribution)

@@ -248,3 +245,42 @@

  step_count_per_record = config.get('step_count_per_record', 1)
  validate_step_count_per_record(step_count_per_record)
+
+ squash_name = config.get('squash_name', True)
+ validate_squash_name(squash_name)
+
+ if not targets:
+ if xy_distribution:
+ config["all_xy"] = True
+ config["targets"] = {"": {}}
+
+
+ def time_str2time_digit(time_str):
+ time_format = '%b%d_%H-%M-%S'
+ try:
+ time_digit = datetime.strptime(time_str, time_format)
+ except Exception as e:
+ raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \
+ of existing output dirpath, like 'Dec03_21-34-40'.") from e
+ return time_digit
+
+
+ def get_target_output_dir(monitor_path, time_start, time_end):
+ check_file_or_directory_path(monitor_path, isdir=True)
+ time_start = time_str2time_digit(time_start) if time_start is not None else time_start
+ time_end = time_str2time_digit(time_end) if time_end is not None else time_end
+ if time_start and time_end and time_start > time_end:
+ raise ValueError(f"time_start({time_start}) greater than time_end({time_end})")
+ result = {}
+ for dirname in os.listdir(monitor_path):
+ match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname)
+ if not match:
+ continue
+ time_tag = match.group(1)
+ rank = match.group(2)
+ target_time = time_str2time_digit(time_tag)
+ start_ok = time_start is None or target_time >= time_start
+ end_ok = time_end is None or target_time <= time_end
+ if start_ok and end_ok:
+ result[rank] = os.path.join(monitor_path, dirname)
+ return result
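
Two of the helpers added above can be exercised in isolation. The sketch below is illustrative, not part of the diff; the module path follows the file list above.

```python
# Illustrative sketch only (assumes mindstudio-probe 1.2.2 is installed).
import torch
from msprobe.pytorch.monitor.utils import get_nan_tensor, get_output_base_dir

print(get_output_base_dir())    # env override via MonitorConst.MONITOR_OUTPUT_DIR, else the default dir
nan = get_nan_tensor()          # cached scalar NaN placed on npu/cuda when available, else cpu
print(torch.isnan(nan).item())  # True
```
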