mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/RECORD +37 -47
- msprobe/README.md +8 -5
- msprobe/core/common/const.py +17 -3
- msprobe/core/common/file_utils.py +64 -13
- msprobe/core/common/framework_adapter.py +10 -1
- msprobe/core/common/utils.py +17 -0
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
- msprobe/core/hook_manager.py +2 -16
- msprobe/core/service.py +5 -16
- msprobe/docs/01.installation.md +2 -0
- msprobe/docs/02.config_introduction.md +0 -13
- msprobe/docs/14.data_parse_PyTorch.md +2 -0
- msprobe/docs/21.visualization_PyTorch.md +1 -1
- msprobe/docs/25.tool_function_introduction.md +0 -1
- msprobe/docs/32.ckpt_compare.md +5 -5
- msprobe/mindspore/monitor/module_hook.py +17 -20
- msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
- msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
- msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
- msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
- msprobe/pytorch/common/utils.py +0 -70
- msprobe/pytorch/debugger/debugger_config.py +0 -10
- msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
- msprobe/pytorch/hook_module/api_register.py +5 -1
- msprobe/pytorch/monitor/module_hook.py +16 -34
- msprobe/pytorch/pt_config.py +2 -51
- msprobe/pytorch/pytorch_service.py +2 -11
- msprobe/visualization/builder/graph_builder.py +2 -2
- msprobe/visualization/builder/graph_merger.py +13 -0
- msprobe/visualization/graph/graph.py +13 -9
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
- msprobe/pytorch/attl_manager.py +0 -65
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.1.dist-info}/top_level.txt +0 -0
|
@@ -73,14 +73,7 @@
|
|
|
73
73
|
| data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 |
|
|
74
74
|
| file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:<br/> "bin":dump 的 tensor 文件为二进制格式;<br/>"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 |
|
|
75
75
|
| summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:<br/> md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;<br/> statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 |
|
|
76
|
-
| online_run_ut<sup>a</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 |
|
|
77
|
-
| nfs_path<sup>a</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 |
|
|
78
|
-
| host<sup>a</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
|
|
79
|
-
| port<sup>a</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 |
|
|
80
76
|
|
|
81
|
-
**说明**:
|
|
82
|
-
|
|
83
|
-
1. online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。
|
|
84
77
|
|
|
85
78
|
**示例**:
|
|
86
79
|
- [PyTorch场景](03.config_examples.md#12-task-配置为-tensor)
|
|
@@ -95,17 +88,11 @@
|
|
|
95
88
|
| white_list<sup>a</sup> | API dump 白名单,仅对指定的 API 进行 dump。<br/>**配置示例**:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即 dump 全量 API 数据。 | 否 |
|
|
96
89
|
| black_list<sup>a</sup> | API dump 黑名单,被指定的 API 不进行 dump。<br/>**配置示例**:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即 dump 全量 API 数据。 | 否 |
|
|
97
90
|
| error_data_path | 配置保存精度未达标的 API 输入输出数据路径,默认为当前路径。<br/>**配置示例**:"error_data_path": "./"。 | 否 |
|
|
98
|
-
| is_online<sup>b</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认关闭。 | 否 |
|
|
99
|
-
| nfs_path<sup>b</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效,仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
100
|
-
| host<sup>b</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
101
|
-
| port<sup>b</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 |
|
|
102
|
-
| rank_list<sup>b</sup> | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 |
|
|
103
91
|
|
|
104
92
|
**说明**:
|
|
105
93
|
|
|
106
94
|
1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。
|
|
107
95
|
|
|
108
|
-
2. is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。
|
|
109
96
|
|
|
110
97
|
**示例**:
|
|
111
98
|
```json
|
|
@@ -8,6 +8,8 @@
|
|
|
8
8
|
|
|
9
9
|
依赖:CANN 包中的 msaccucmp 工具,需要安装 Ascend-CANN-toolkit,详见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0001.html)》。
|
|
10
10
|
|
|
11
|
+
**安全注意事项**: 工具使用时传入的msaccucmp.py文件路径参数请务必确保为CANN中自带的msaccucmp.py原始路径且文件内容未被篡改,工具会在解析数据时直接执行该文件,用户需自行保证该文件内容安全性。
|
|
12
|
+
|
|
11
13
|
## 2 数据解析操作指导
|
|
12
14
|
|
|
13
15
|
### 2.1 进入 parse 交互式界面
|
|
@@ -7,7 +7,6 @@
|
|
|
7
7
|
| [数据采集<br>(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析<br> 2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API<br>2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失<br>3、当前对inplace操作API或Module的支持度有限<br>4、暂不支持参数及参数梯度的采集 |
|
|
8
8
|
| [离线预检<br>(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查<br>2、精度排查不受模型累计误差影响 | 1、依赖GPU环境<br>2、不支持通信算子<br>3、仅支持部分融合算子 |
|
|
9
9
|
| [整网比对<br>(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响<br>2、当模型规模较大时,比对所需时间较长 |
|
|
10
|
-
| [在线预检<br>(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信<br>2、重计算模式下,不支持反向aten算子预检 |
|
|
11
10
|
| [溢出检查<br>(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module<br>2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 |
|
|
12
11
|
| [数据解析<br>(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU |
|
|
13
12
|
| [无标杆比对<br>(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查<br>2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用<br>2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 |
|
msprobe/docs/32.ckpt_compare.md
CHANGED
|
@@ -19,11 +19,11 @@ Megatron、MindSpeed的ckpt加载依赖megatron,请确保megatron在python环
|
|
|
19
19
|
msprobe --framework pytorch config_check --compare path1 path2 -o output_path.json
|
|
20
20
|
```
|
|
21
21
|
|
|
22
|
-
| 参数名 | 解释
|
|
23
|
-
|
|
24
|
-
| -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。
|
|
25
|
-
| -c 或 --compare | 2个ckpt
|
|
26
|
-
| -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。
|
|
22
|
+
| 参数名 | 解释 | 是否必选 |
|
|
23
|
+
|--------|----------------------------------------------------------------------------------------------|--------|
|
|
24
|
+
| -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 |
|
|
25
|
+
| -c 或 --compare | 2个ckpt的路径。在ckpt传给工具加载前,用户需要确保ckpt是安全可信的,若ckpt来源官方有提供SHA256等校验值,用户必须要进行校验,以确保ckpt没有被篡改。 | 是 |
|
|
26
|
+
| -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
|
|
27
27
|
|
|
28
28
|
Megatron-LM 和 MindSpeed 的 ckpt 目录结构如下:
|
|
29
29
|
|
|
@@ -163,15 +163,11 @@ class GradContext:
|
|
|
163
163
|
def __init__(self) -> None:
|
|
164
164
|
self.pre = {}
|
|
165
165
|
self.post = {}
|
|
166
|
-
self.acc_metric = {}
|
|
167
|
-
self.acc = {}
|
|
168
166
|
self.actv = {}
|
|
169
167
|
|
|
170
168
|
def reset(self):
|
|
171
169
|
self.pre.clear()
|
|
172
170
|
self.post.clear()
|
|
173
|
-
self.acc_metric.clear()
|
|
174
|
-
self.acc.clear()
|
|
175
171
|
self.actv.clear()
|
|
176
172
|
|
|
177
173
|
|
|
@@ -312,7 +308,6 @@ class TrainerMon:
|
|
|
312
308
|
self.recording_l2_features = self.config.get('recording_l2_features', False)
|
|
313
309
|
self.sa_order = self.config.get('sa_order', "s,b,h,d")
|
|
314
310
|
|
|
315
|
-
|
|
316
311
|
if not self.cc_distribution.get('enable', False):
|
|
317
312
|
self.cc_log_only = False
|
|
318
313
|
else:
|
|
@@ -403,7 +398,7 @@ class TrainerMon:
|
|
|
403
398
|
if self.monitoring:
|
|
404
399
|
module_rank_valid = self.is_target_rank()
|
|
405
400
|
step_condition = (context.step >= self.start_step and (
|
|
406
|
-
|
|
401
|
+
context.step - self.start_step) % self.step_interval == 0)
|
|
407
402
|
if module_rank_valid and step_condition:
|
|
408
403
|
self.has_collect_times += 1
|
|
409
404
|
|
|
@@ -447,6 +442,7 @@ class TrainerMon:
|
|
|
447
442
|
hook(optimizer, args, kwargs)
|
|
448
443
|
step_final_hook(optimizer, args, kwargs)
|
|
449
444
|
return out
|
|
445
|
+
|
|
450
446
|
return wrapper
|
|
451
447
|
|
|
452
448
|
if self.is_mindtorch:
|
|
@@ -541,7 +537,7 @@ class TrainerMon:
|
|
|
541
537
|
if self.mv_distribution or self.ur_distribution or self.mg_direction:
|
|
542
538
|
if self.is_mindtorch:
|
|
543
539
|
context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
|
|
544
|
-
|
|
540
|
+
context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
|
|
545
541
|
else:
|
|
546
542
|
context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
|
|
547
543
|
|
|
@@ -564,7 +560,6 @@ class TrainerMon:
|
|
|
564
560
|
context = self.optimizer_context[optimizer]
|
|
565
561
|
self.generate_param_metrics(context, MonitorConst.POST_PARAM)
|
|
566
562
|
|
|
567
|
-
|
|
568
563
|
if self.optimizer_hooked or not self.is_target_rank():
|
|
569
564
|
return
|
|
570
565
|
|
|
@@ -577,7 +572,6 @@ class TrainerMon:
|
|
|
577
572
|
if not self.wg_distribution:
|
|
578
573
|
return
|
|
579
574
|
|
|
580
|
-
get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
|
|
581
575
|
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
|
|
582
576
|
|
|
583
577
|
def generate_param_map(self, tag, param_tensor):
|
|
@@ -671,7 +665,7 @@ class TrainerMon:
|
|
|
671
665
|
if not self.wg_distribution:
|
|
672
666
|
return
|
|
673
667
|
|
|
674
|
-
self.summary_writer.write_metrics(self.ops, self.grad_context.
|
|
668
|
+
self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
|
|
675
669
|
use_micro_step=self.monitor_mbs_grad)
|
|
676
670
|
self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
|
|
677
671
|
|
|
@@ -810,7 +804,7 @@ class TrainerMon:
|
|
|
810
804
|
f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
|
|
811
805
|
MonitorConst.ACTV, module_input))
|
|
812
806
|
module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
|
|
813
|
-
|
|
807
|
+
if isinstance(module_output, tuple) else module_output
|
|
814
808
|
tbtag_tensor_map.update(
|
|
815
809
|
self.build_tbtag_tensor_map(
|
|
816
810
|
f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
|
|
@@ -868,11 +862,13 @@ class TrainerMon:
|
|
|
868
862
|
if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
|
|
869
863
|
def wrapper(module, args, kwargs, module_output):
|
|
870
864
|
return fwd_hook_fun(module, args, kwargs, module_output, name)
|
|
865
|
+
|
|
871
866
|
return module.register_forward_hook(wrapper, with_kwargs=True)
|
|
872
867
|
|
|
873
868
|
else:
|
|
874
869
|
def wrapper(module, args, module_output):
|
|
875
870
|
return fwd_hook_fun(module, args, None, module_output, name)
|
|
871
|
+
|
|
876
872
|
return module.register_forward_hook(wrapper)
|
|
877
873
|
|
|
878
874
|
def extract_attention_feature_hook(module, args, kwargs, module_output, name):
|
|
@@ -880,7 +876,7 @@ class TrainerMon:
|
|
|
880
876
|
if kwargs:
|
|
881
877
|
kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
|
|
882
878
|
module_input.extend(kwargs_tensors)
|
|
883
|
-
|
|
879
|
+
|
|
884
880
|
if module not in self.feature_hook_context_by_module:
|
|
885
881
|
self.feature_hook_context_by_module[module] = FeatureHookContext(name)
|
|
886
882
|
context: FeatureHookContext = self.feature_hook_context_by_module[module]
|
|
@@ -890,7 +886,7 @@ class TrainerMon:
|
|
|
890
886
|
logger.warning(
|
|
891
887
|
"Calculate attention feature failed, the length of module_input in attention hook's module should "
|
|
892
888
|
"be greater than or equal to 2.")
|
|
893
|
-
|
|
889
|
+
|
|
894
890
|
q_h = module_input[0]
|
|
895
891
|
k_h = module_input[1]
|
|
896
892
|
qkt = cal_qkt(q_h, k_h, order=self.sa_order)
|
|
@@ -985,25 +981,26 @@ class TrainerMon:
|
|
|
985
981
|
self._hook_weights()
|
|
986
982
|
|
|
987
983
|
def _hook_weights(self):
|
|
988
|
-
context = self.grad_context
|
|
989
984
|
|
|
990
985
|
@_no_grad()
|
|
991
|
-
def param_hook(grad,
|
|
986
|
+
def param_hook(grad, param, name):
|
|
992
987
|
key = name
|
|
993
988
|
if self.monitor_mbs_grad:
|
|
994
989
|
key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
|
|
995
990
|
key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
|
|
996
991
|
self.register_param_call_id("param_hook", key)
|
|
997
992
|
param.micro_step += 1
|
|
998
|
-
|
|
993
|
+
grad_dict = {}
|
|
999
994
|
if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
|
|
1000
|
-
|
|
995
|
+
grad_dict[key] = grad
|
|
996
|
+
get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
|
|
997
|
+
|
|
1001
998
|
if param.micro_step == self.micro_batch_number:
|
|
1002
999
|
param.micro_step = 0
|
|
1003
1000
|
|
|
1004
|
-
def param_hook_wrapper(param_hook,
|
|
1001
|
+
def param_hook_wrapper(param_hook, param, name):
|
|
1005
1002
|
def wrapper(grad):
|
|
1006
|
-
return param_hook(grad,
|
|
1003
|
+
return param_hook(grad, param, name)
|
|
1007
1004
|
|
|
1008
1005
|
return wrapper
|
|
1009
1006
|
|
|
@@ -1011,7 +1008,7 @@ class TrainerMon:
|
|
|
1011
1008
|
for param, name in self.param2name.items():
|
|
1012
1009
|
setattr(param, 'micro_step', 0)
|
|
1013
1010
|
handle = param.register_hook(
|
|
1014
|
-
param_hook_wrapper(param_hook,
|
|
1011
|
+
param_hook_wrapper(param_hook, param=param, name=name))
|
|
1015
1012
|
self.handles['wgrads'].append(handle)
|
|
1016
1013
|
self.weight_hooked = True
|
|
1017
1014
|
|
|
@@ -24,8 +24,7 @@ from msprobe.pytorch.pt_config import RunUTConfig
|
|
|
24
24
|
|
|
25
25
|
RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
|
|
26
26
|
'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
|
|
27
|
-
'black_list', 'error_data_path'
|
|
28
|
-
OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
|
|
27
|
+
'black_list', 'error_data_path'])
|
|
29
28
|
|
|
30
29
|
|
|
31
30
|
class Config:
|
|
@@ -46,13 +45,7 @@ class Config:
|
|
|
46
45
|
'white_list': list,
|
|
47
46
|
'black_list': list,
|
|
48
47
|
'error_data_path': str,
|
|
49
|
-
'precision': int
|
|
50
|
-
'is_online': bool,
|
|
51
|
-
'nfs_path': str,
|
|
52
|
-
'host': str,
|
|
53
|
-
'port': int,
|
|
54
|
-
'rank_list': list,
|
|
55
|
-
'tls_path': str
|
|
48
|
+
'precision': int
|
|
56
49
|
}
|
|
57
50
|
if key not in validators:
|
|
58
51
|
raise ValueError(f"{key} must be one of {validators.keys()}")
|
|
@@ -68,10 +61,6 @@ class Config:
|
|
|
68
61
|
RunUTConfig.check_filter_list_config(key, value)
|
|
69
62
|
if key == 'error_data_path':
|
|
70
63
|
RunUTConfig.check_error_data_path_config(value)
|
|
71
|
-
if key == 'nfs_path':
|
|
72
|
-
RunUTConfig.check_nfs_path_config(value)
|
|
73
|
-
if key == 'tls_path':
|
|
74
|
-
RunUTConfig.check_tls_path_config(value)
|
|
75
64
|
return value
|
|
76
65
|
|
|
77
66
|
|
|
@@ -85,12 +74,6 @@ class CheckerConfig:
|
|
|
85
74
|
self.white_list = msCheckerConfig.white_list
|
|
86
75
|
self.black_list = msCheckerConfig.black_list
|
|
87
76
|
self.error_data_path = msCheckerConfig.error_data_path
|
|
88
|
-
self.is_online = msCheckerConfig.is_online
|
|
89
|
-
self.nfs_path = msCheckerConfig.nfs_path
|
|
90
|
-
self.host = msCheckerConfig.host
|
|
91
|
-
self.port = msCheckerConfig.port
|
|
92
|
-
self.rank_list = msCheckerConfig.rank_list
|
|
93
|
-
self.tls_path = msCheckerConfig.tls_path
|
|
94
77
|
|
|
95
78
|
if task_config:
|
|
96
79
|
self.load_config(task_config)
|
|
@@ -99,22 +82,7 @@ class CheckerConfig:
|
|
|
99
82
|
self.white_list = task_config.white_list
|
|
100
83
|
self.black_list = task_config.black_list
|
|
101
84
|
self.error_data_path = task_config.error_data_path
|
|
102
|
-
self.is_online = task_config.is_online
|
|
103
|
-
self.nfs_path = task_config.nfs_path
|
|
104
|
-
self.host = task_config.host
|
|
105
|
-
self.port = task_config.port
|
|
106
|
-
self.rank_list = task_config.rank_list
|
|
107
|
-
self.tls_path = task_config.tls_path
|
|
108
85
|
|
|
109
|
-
def get_online_config(self):
|
|
110
|
-
return OnlineConfig(
|
|
111
|
-
is_online=self.is_online,
|
|
112
|
-
nfs_path=self.nfs_path,
|
|
113
|
-
host=self.host,
|
|
114
|
-
port=self.port,
|
|
115
|
-
rank_list=self.rank_list,
|
|
116
|
-
tls_path=self.tls_path
|
|
117
|
-
)
|
|
118
86
|
|
|
119
87
|
def get_run_ut_config(self, **config_params):
|
|
120
88
|
return RunUtConfig(
|
|
@@ -127,6 +95,5 @@ class CheckerConfig:
|
|
|
127
95
|
real_data_path=config_params.get('real_data_path'),
|
|
128
96
|
white_list=self.white_list.copy() if self.white_list else [],
|
|
129
97
|
black_list=self.black_list.copy() if self.black_list else [],
|
|
130
|
-
error_data_path=config_params.get('error_data_path')
|
|
131
|
-
online_config=self.get_online_config()
|
|
98
|
+
error_data_path=config_params.get('error_data_path')
|
|
132
99
|
)
|
|
@@ -117,30 +117,6 @@ def api_precision_compare(config):
|
|
|
117
117
|
change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
118
118
|
|
|
119
119
|
|
|
120
|
-
def online_api_precision_compare(online_config):
|
|
121
|
-
rank = online_config.rank
|
|
122
|
-
result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
|
|
123
|
-
"_rank*.csv", f"_rank{rank}.csv")
|
|
124
|
-
details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
|
|
125
|
-
"_rank*.csv", f"_rank{rank}.csv")
|
|
126
|
-
detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
|
|
127
|
-
result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
|
|
128
|
-
if not os.path.exists(result_csv_path):
|
|
129
|
-
write_csv(result_csv_title, result_csv_path)
|
|
130
|
-
if not os.path.exists(details_csv_path):
|
|
131
|
-
write_csv(detail_csv_title, details_csv_path)
|
|
132
|
-
config = CompareConfig("", "", result_csv_path, details_csv_path)
|
|
133
|
-
try:
|
|
134
|
-
npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
|
|
135
|
-
check_csv_columns(npu_data.columns, "npu_csv")
|
|
136
|
-
check_csv_columns(gpu_data.columns, "gpu_csv")
|
|
137
|
-
analyse_csv(npu_data, gpu_data, config)
|
|
138
|
-
except Exception as err:
|
|
139
|
-
logger.error(f"Online api precision compare Error: {str(err)}")
|
|
140
|
-
change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
141
|
-
change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
|
|
142
|
-
|
|
143
|
-
|
|
144
120
|
def analyse_csv(npu_data, gpu_data, config):
|
|
145
121
|
forward_status, backward_status = [], []
|
|
146
122
|
last_api_name, last_api_dtype, last_api_full_name = None, None, None
|
|
@@ -66,13 +66,6 @@ class Comparator:
|
|
|
66
66
|
self.save_path_list = [result_csv_path]
|
|
67
67
|
self.detail_save_path_list = [details_csv_path]
|
|
68
68
|
|
|
69
|
-
if config and config.online_config.is_online:
|
|
70
|
-
self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
|
|
71
|
-
self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
|
|
72
|
-
self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
73
|
-
self.detail_save_path_list = \
|
|
74
|
-
[self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
|
|
75
|
-
|
|
76
69
|
self.registry = self._register_compare_func()
|
|
77
70
|
|
|
78
71
|
if not is_continue_run_ut:
|
|
@@ -245,9 +238,8 @@ class Comparator:
|
|
|
245
238
|
self.write_detail_csv(args)
|
|
246
239
|
|
|
247
240
|
|
|
248
|
-
def compare_output(self, full_api_name, data_info
|
|
241
|
+
def compare_output(self, full_api_name, data_info):
|
|
249
242
|
"""Get compare result and write to result and detail csv.
|
|
250
|
-
is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
|
|
251
243
|
"""
|
|
252
244
|
_, api_name = extract_basic_api_segments(full_api_name)
|
|
253
245
|
if not api_name:
|
|
@@ -280,9 +272,7 @@ class Comparator:
|
|
|
280
272
|
fwd_compare_alg_results,
|
|
281
273
|
bwd_compare_alg_results,
|
|
282
274
|
data_info.rank)
|
|
283
|
-
|
|
284
|
-
# get run_ut compare detail
|
|
285
|
-
return self._get_run_ut_detail(result_info)
|
|
275
|
+
|
|
286
276
|
self.record_results(result_info)
|
|
287
277
|
return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
|
|
288
278
|
or bwd_success_status == CompareConst.SPACE
|
|
@@ -39,7 +39,12 @@ from msprobe.core.common.const import FileCheckConst, Const
|
|
|
39
39
|
from msprobe.core.common.utils import CompareException
|
|
40
40
|
|
|
41
41
|
|
|
42
|
-
def split_json_file(input_file, num_splits, filter_api):
|
|
42
|
+
def split_json_file(input_file, num_splits, filter_api, device_id):
|
|
43
|
+
max_processes = len(device_id) * 8
|
|
44
|
+
if num_splits > max_processes:
|
|
45
|
+
logger.warning(f"A device supports a maximum of 8 processes. "
|
|
46
|
+
f"The total number of processes exceeds the limit, and it is set to {max_processes}.")
|
|
47
|
+
num_splits = max_processes
|
|
43
48
|
forward_data, backward_data, real_data_path = parse_json_info_forward_backward(input_file)
|
|
44
49
|
input_dir = os.path.dirname(os.path.abspath(input_file))
|
|
45
50
|
if filter_api:
|
|
@@ -88,7 +93,7 @@ def split_json_file(input_file, num_splits, filter_api):
|
|
|
88
93
|
logger.error(f"File not found or could not be deleted: {file}")
|
|
89
94
|
msg = 'ERROR: Split json file failed, please check the input file and try again.'
|
|
90
95
|
raise CompareException(CompareException.PARSE_FILE_ERROR, msg) from e
|
|
91
|
-
return split_files, total_items
|
|
96
|
+
return split_files, total_items, num_splits
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
def signal_handler(signum, frame):
|
|
@@ -127,7 +132,8 @@ def run_parallel_ut(config):
|
|
|
127
132
|
def read_process_output(process):
|
|
128
133
|
try:
|
|
129
134
|
while True:
|
|
130
|
-
|
|
135
|
+
# 子进程标准输出流与进程本身状态是分开的,因此增加判断。子进程返回值非None表示子进程结束,标准输出为None表示结束。
|
|
136
|
+
if process.poll() is not None or process.stdout is None:
|
|
131
137
|
break
|
|
132
138
|
output = process.stdout.readline()
|
|
133
139
|
if output == '':
|
|
@@ -175,12 +181,17 @@ def run_parallel_ut(config):
|
|
|
175
181
|
|
|
176
182
|
try:
|
|
177
183
|
for process in processes:
|
|
178
|
-
process.
|
|
184
|
+
process.wait() # wait仅阻塞,不捕获标准输出和标准错误,原communicate不仅阻塞,而且捕获标准输出和标准错误
|
|
179
185
|
except KeyboardInterrupt:
|
|
180
186
|
logger.warning("Interrupted by user, terminating processes and cleaning up...")
|
|
181
187
|
except Exception as e:
|
|
182
188
|
logger.error(f"An unexpected error occurred: {e}")
|
|
183
189
|
finally:
|
|
190
|
+
# 最后再更新一次进度条,避免因缓存写入等原因子进程结束而进度未刷新的问题
|
|
191
|
+
if wait_for_file_write_complete(config.result_csv_path):
|
|
192
|
+
result_file = read_csv(config.result_csv_path)
|
|
193
|
+
completed_items = len(result_file)
|
|
194
|
+
progress_bar.update(completed_items - progress_bar.n)
|
|
184
195
|
if progress_bar.n < config.total_items:
|
|
185
196
|
logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to " \
|
|
186
197
|
"the result CSV file will be utilized to resume the UT task.")
|
|
@@ -195,6 +206,22 @@ def run_parallel_ut(config):
|
|
|
195
206
|
logger.error(f"An unexpected error occurred: {e}")
|
|
196
207
|
|
|
197
208
|
|
|
209
|
+
def wait_for_file_write_complete(file_path, timeout=3600):
|
|
210
|
+
last_size = 0
|
|
211
|
+
start_time = time.time() # 记录开始时间
|
|
212
|
+
while True:
|
|
213
|
+
current_size = os.path.getsize(file_path)
|
|
214
|
+
# 检查是否文件大小未变化
|
|
215
|
+
if current_size == last_size:
|
|
216
|
+
return True # 文件写入完成,返回 True
|
|
217
|
+
last_size = current_size
|
|
218
|
+
# 检查是否超时
|
|
219
|
+
if time.time() - start_time > timeout:
|
|
220
|
+
logger.error("write the result csv file timeout.")
|
|
221
|
+
return False # 超时,返回 False
|
|
222
|
+
time.sleep(0.1) # 适当的延时
|
|
223
|
+
|
|
224
|
+
|
|
198
225
|
def prepare_config(args):
|
|
199
226
|
api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
|
|
200
227
|
ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
|
|
@@ -203,7 +230,9 @@ def prepare_config(args):
|
|
|
203
230
|
create_directory(out_path)
|
|
204
231
|
out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE)
|
|
205
232
|
out_path = out_path_checker.common_check()
|
|
206
|
-
split_files, total_items = split_json_file(api_info, args.num_splits,
|
|
233
|
+
split_files, total_items, modified_num_splits = split_json_file(api_info, args.num_splits,
|
|
234
|
+
args.filter_api, args.device_id)
|
|
235
|
+
args.num_splits = modified_num_splits
|
|
207
236
|
config_path = args.config_path if args.config_path else None
|
|
208
237
|
if config_path:
|
|
209
238
|
config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
|