mindstudio-probe 8.3.0__py3-none-any.whl → 8.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/RECORD +44 -54
  3. msprobe/README.md +8 -5
  4. msprobe/core/common/const.py +17 -3
  5. msprobe/core/common/file_utils.py +64 -13
  6. msprobe/core/common/framework_adapter.py +10 -1
  7. msprobe/core/common/utils.py +17 -0
  8. msprobe/core/compare/utils.py +26 -6
  9. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +6 -1
  10. msprobe/core/hook_manager.py +2 -16
  11. msprobe/core/service.py +5 -16
  12. msprobe/docs/01.installation.md +2 -0
  13. msprobe/docs/02.config_introduction.md +0 -13
  14. msprobe/docs/05.data_dump_PyTorch.md +1 -1
  15. msprobe/docs/07.accuracy_checker_PyTorch.md +13 -13
  16. msprobe/docs/10.accuracy_compare_PyTorch.md +6 -6
  17. msprobe/docs/14.data_parse_PyTorch.md +2 -0
  18. msprobe/docs/19.monitor.md +4 -4
  19. msprobe/docs/21.visualization_PyTorch.md +1 -1
  20. msprobe/docs/25.tool_function_introduction.md +0 -1
  21. msprobe/docs/32.ckpt_compare.md +5 -5
  22. msprobe/mindspore/monitor/module_hook.py +17 -20
  23. msprobe/pytorch/api_accuracy_checker/common/config.py +3 -36
  24. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +0 -24
  25. msprobe/pytorch/api_accuracy_checker/compare/compare.py +2 -12
  26. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -6
  27. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +34 -5
  28. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +12 -132
  29. msprobe/pytorch/common/utils.py +0 -70
  30. msprobe/pytorch/debugger/debugger_config.py +0 -10
  31. msprobe/pytorch/dump/module_dump/module_processer.py +18 -3
  32. msprobe/pytorch/hook_module/api_register.py +14 -3
  33. msprobe/pytorch/monitor/module_hook.py +16 -34
  34. msprobe/pytorch/pt_config.py +2 -51
  35. msprobe/pytorch/pytorch_service.py +10 -14
  36. msprobe/visualization/builder/graph_builder.py +2 -2
  37. msprobe/visualization/builder/graph_merger.py +13 -0
  38. msprobe/visualization/db_utils.py +42 -18
  39. msprobe/visualization/graph/graph.py +13 -9
  40. msprobe/visualization/graph_service.py +20 -10
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +0 -295
  42. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py +0 -0
  43. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +0 -205
  44. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +0 -378
  45. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +0 -239
  46. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +0 -115
  47. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +0 -250
  48. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +0 -63
  49. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +0 -198
  50. msprobe/pytorch/attl_manager.py +0 -65
  51. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/LICENSE +0 -0
  52. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/WHEEL +0 -0
  53. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/entry_points.txt +0 -0
  54. {mindstudio_probe-8.3.0.dist-info → mindstudio_probe-8.3.2.dist-info}/top_level.txt +0 -0
msprobe/core/service.py CHANGED
@@ -35,7 +35,6 @@ class BaseService(ABC):
35
35
  self.config.level = getattr(config, 'level_ori', config.level) # 兼容MindSpore配置
36
36
  self.model = None
37
37
  self.data_collector = build_data_collector(self.config)
38
- self.attl_manager = None
39
38
  self.current_iter = 0
40
39
  self.loop = 0
41
40
  self.init_step = 0
@@ -91,10 +90,6 @@ class BaseService(ABC):
91
90
  self.config.task in self.data_collector.tasks_need_tensor_data or
92
91
  (self.config.task == Const.STATISTICS and self.config.tensor_list)
93
92
  )
94
-
95
- @property
96
- def _is_online_run_ut(self):
97
- return getattr(self.config, "online_run_ut", False)
98
93
 
99
94
  @property
100
95
  @abstractmethod
@@ -146,11 +141,9 @@ class BaseService(ABC):
146
141
  self.primitive_switch = True
147
142
  self._change_jit_switch(True)
148
143
  self.logger.info(f"Dump switch is turned on at step {self.current_iter}. ")
149
- if self._is_online_run_ut:
150
- self._run_ut_dispatch(True)
151
- else:
152
- self.create_dirs()
153
- self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
144
+
145
+ self.create_dirs()
146
+ self.logger.info(f"Dump data will be saved in {self.dump_iter_dir}.")
154
147
 
155
148
  def stop(self):
156
149
  """通用stop模板"""
@@ -165,8 +158,7 @@ class BaseService(ABC):
165
158
  self._change_jit_switch(False)
166
159
  if self._is_l2_level:
167
160
  return
168
- if self._is_online_run_ut:
169
- self._run_ut_dispatch(False)
161
+
170
162
  self._process_async_dump()
171
163
  self.data_collector.write_json()
172
164
 
@@ -266,8 +258,6 @@ class BaseService(ABC):
266
258
  end_service = self.config.step and self.current_iter > max(self.config.step) or \
267
259
  self.data_collector and self.data_collector.data_processor.is_terminated
268
260
  if end_service:
269
- if self._is_online_run_ut and self.attl_manager:
270
- self.attl_manager.attl_stop()
271
261
  self.primitive_switch = False
272
262
  self._change_jit_switch(False)
273
263
  Runtime.is_running = False
@@ -310,8 +300,7 @@ class BaseService(ABC):
310
300
  if root_model and isinstance(root_model, list):
311
301
  root_model = root_model[0]
312
302
  self.logger.warning("Infer model can only input one to support token_range, choose the first one.")
313
- if self._is_online_run_ut:
314
- return
303
+
315
304
  root_model.register_forward_pre_hook(infer_hook)
316
305
 
317
306
  def _create_l2_dirs(self, cur_rank):
@@ -16,6 +16,8 @@ pip install mindstudio-probe
16
16
 
17
17
  | 版本 | 发布日期 |支持 PyTorch 版本|支持 MindSpore 版本| 下载链接 |校验码|
18
18
  |:-----:|:----------:|:--:|:--:|:----------------------------------------------------------------------------------------------------------------------------------:|:--:|
19
+ | 8.3.0 | 2025.10.30 |1.11/2.0/2.1/2.2/2.5/2.6/2.7|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.3.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.3/mindstudio_probe-8.3.0-py3-none-any.whl) |e933657b8ceb20774f924865d1f47978bd49cd1d9cec5fb60dec4f18802afbcc|
20
+ | 8.2.1 | 2025.9.29 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.1-py3-none-any.whl) |2152fadefec3d70148a910a54f2cfb6fccfa60b5ccd59f835f20289a7a5f4764|
19
21
  | 8.2.0 | 2025.9.03 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.2.0-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.2/mindstudio_probe-8.2.0-py3-none-any.whl) |bbc1577d76754adf987069308177d3e0a04e36de9c7f22e75c34cf4ad0ce1af2|
20
22
  | 8.1.2 | 2025.8.01 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.2-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.2-py3-none-any.whl) |ff07bb81fddd3b8f3096d119ca1481bde8fdb24f10644def5250caad727448ab|
21
23
  | 8.1.1 | 2025.6.20 |1.11/2.0/2.1/2.2|2.4.0/2.5.0/2.6.0| [mindstudio_probe-8.1.1-py3-none-any.whl](https://ptdbg.obs.myhuaweicloud.com/msprobe/8.1/mindstudio_probe-8.1.1-py3-none-any.whl) |2aad10a243575544d7feef552caf4d06aa93028488ebd0bbc9aa350379da859d|
@@ -73,14 +73,7 @@
73
73
  | data_mode | 与[ 1.2 task 配置为 statistics ](#12-task-配置为-statistics)中的解释相同 | 否 |
74
74
  | file_format | tensor 数据的保存格式,str 类型,仅支持 MindSpore 静态图场景的 L2 级别配置该字段,其他场景不生效。可选参数:<br/> "bin":dump 的 tensor 文件为二进制格式;<br/>"npy":dump 的 tensor 文件后缀为 .npy,默认值。 | 否 |
75
75
  | summary_mode | 控制 dump 文件输出的模式,str 类型,支持 PyTorch、MSAdapter、MindSpore 动态图。可选参数:<br/> md5:dump 输出包含 CRC-32 值以及 API 统计信息的 dump.json 文件,用于验证数据的完整性;<br/> statistics:dump 仅输出包含 API 统计信息的 dump.json 文件,默认值。| 否 |
76
- | online_run_ut<sup>a</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认未配置,表示关闭。配置为 true 表示开启在线预检。| 否 |
77
- | nfs_path<sup>a</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。仅在 online_run_ut 字段配置为 true 时生效,配置该参数后 host 和 port 不生效。 | 否 |
78
- | host<sup>a</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的局域网 IP 地址。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。 | 否 |
79
- | port<sup>a</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,NPU 侧须配置为 GPU 侧的端口号。仅在 online_run_ut 字段配置为 true 时生效,局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。| 否 |
80
76
 
81
- **说明**:
82
-
83
- 1. online_run_ut、nfs_path、host、port 等字段仅在线预检场景 NPU 机器生效。
84
77
 
85
78
  **示例**:
86
79
  - [PyTorch场景](03.config_examples.md#12-task-配置为-tensor)
@@ -95,17 +88,11 @@
95
88
  | white_list<sup>a</sup> | API dump 白名单,仅对指定的 API 进行 dump。<br/>**配置示例**:"white_list": ["conv1d", "conv2d"]。默认未配置白名单,即 dump 全量 API 数据。 | 否 |
96
89
  | black_list<sup>a</sup> | API dump 黑名单,被指定的 API 不进行 dump。<br/>**配置示例**:"black_list": ["conv1d", "conv2d"]。默认未配置黑名单,即 dump 全量 API 数据。 | 否 |
97
90
  | error_data_path | 配置保存精度未达标的 API 输入输出数据路径,默认为当前路径。<br/>**配置示例**:"error_data_path": "./"。 | 否 |
98
- | is_online<sup>b</sup> | 在线预检模式开关,bool 类型,可选参数 true(开启)、false(关闭),默认关闭。 | 否 |
99
- | nfs_path<sup>b</sup> | 在线预检模式共享存储目录路径,str 类型,用于 GPU 设备和 NPU 设备间进行通信。配置该参数后 host 和 port 不生效,仅在 is_online 字段配置为 true 时生效。 | 否 |
100
- | host<sup>b</sup> | 在线预检模式局域网场景信息接收端 IP,str 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机地址 127.0.0.1 或本机局域网 IP。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。 | 否 |
101
- | port<sup>b</sup> | 在线预检模式局域网场景信息接收端端口号,int 类型,用于 GPU 设备和 NPU 设备间进行通信,GPU 侧配置为本机可用端口。局域网场景时,不能配置 nfs_path 参数,否则局域网场景不生效。仅在 is_online 字段配置为 true 时生效。| 否 |
102
- | rank_list<sup>b</sup> | 指定在线预检的 Rank ID,默认值为 [0],list[int] 类型,应配置为大于等于 0 的整数,且须根据实际卡的 Rank ID 配置,若所配置的值大于实际训练所运行的卡的 Rank ID,则在线预检输出数据为空。GPU 和 NPU 须配置一致。仅在 is_online 字段配置为 true 时生效。 | 否 |
103
91
 
104
92
  **说明**:
105
93
 
106
94
  1. white_list 和 black_list 同时配置时,二者配置的 API 名单若无交集,则白名单生效,若 API 名单存在交集,则白名单排除的部分以及交集的 API 不进行 dump。
107
95
 
108
- 2. is_online、nfs_path、host、port、rank_list 等字段仅在线预检场景 GPU 机器生效。
109
96
 
110
97
  **示例**:
111
98
  ```json
@@ -445,7 +445,7 @@ seed_all()
445
445
  debugger = PrecisionDebugger(config_path="./config.json", dump_path="./dump_path")
446
446
  # 模型定义及初始化等操作
447
447
  prompts = ["Hello, my name is"]
448
- sampling_params = SamplingParams(temprature=0.8, top_p=0.95)
448
+ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
449
449
  llm = LLM(model='...')
450
450
  model = llm.llm_engine.model_executor.driver_worker.worker.model_runner.get_model()
451
451
  # 开启数据dump, 指定采集推理模型逐字符循环推理中的第1~3次
@@ -34,17 +34,17 @@ run_ut 预检操作包括以下两种方式:
34
34
  msprobe -f pytorch run_ut -api_info ./dump_path/step{step_number}/rank{rank_number}/dump.json
35
35
  ```
36
36
 
37
- | 参数名称 | 解释 | 是否必选 |
38
- |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- |
39
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
40
- | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
41
- | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
42
- | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
43
- | -j 或 --jit_compile | 开启 jit 编译。 | 否 |
44
- | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 |
45
- | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 |
46
- | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 |
47
- | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 |
37
+ | 参数名称 | 解释 | 是否必选 |
38
+ |-------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ---------------------------------- |
39
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
40
+ | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
41
+ | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
42
+ | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
43
+ | -j 或 --jit_compile | 开启 jit 编译。 | 否 |
44
+ | -d 或 --device | 指定 Device ID,选择 UT 代码运行所在的卡,默认值为 0。 | 否 |
45
+ | -csv_path 或 --result_csv_path | 指定本次运行中断时生成的 `accuracy_checking_result_{timestamp}.csv` 文件路径,执行 run_ut 中断时,若想从中断处继续执行,配置此参数即可。需要指定为上次中断的 `accuracy_checking_result_{timestamp}.csv` 文件。详见 [3.3 断点续检](#33-断点续检)。 | run_ut 操作中断后继续执行场景下必须配置 |
46
+ | -f 或 --filter_api | 过滤模型中除最大值和最小值以外其他参数和结构相同的 API。适用于模型较大且重复 API 较多的场景。 | 否 |
47
+ | -config 或 --config_path | 指定离线预检操作过程中额外配置(包括黑名单、白名单等)的 [config.json](../config.json) 文件,默认未配置。config.json 文件的配置可参考[配置文件介绍](./02.config_introduction.md)。 | 否 |
48
48
 
49
49
  run_ut 执行结果包括 `accuracy_checking_result_{timestamp}.csv` 和 `accuracy_checking_details_{timestamp}.csv` 两个文件。`accuracy_checking_result_{timestamp}.csv` 属于 API 级,标明每个 API 是否通过测试。建议用户先查看 `accuracy_checking_result_{timestamp}.csv` 文件,对于其中没有通过测试的或者特定感兴趣的 API,根据其 API name 字段在 `accuracy_checking_details_{timestamp}.csv` 中查询其各个输出的达标情况以及比较指标。详细介绍请参见[ 4 预检结果](#4-预检结果)。
50
50
 
@@ -104,7 +104,7 @@ msprobe -f pytorch multi_run_ut -api_info ./dump_path/step{step_number}/rank{ran
104
104
 
105
105
  | 参数名称 | 解释 | 是否必选 |
106
106
  | ---------------------------- | ------------------------------------------------------------ | ---------------------------------- |
107
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
107
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
108
108
  | -api_info 或 --api_info_file | 指定 API 信息文件 dump.json。 | 是 |
109
109
  | -save_error_data | 保存精度未达标的 API 输入输出数据。 | 否 |
110
110
  | -o 或 --out_path | 指定 run_ut 执行结果存盘路径,默认“./”。 | 否 |
@@ -216,7 +216,7 @@ msprobe -f pytorch api_precision_compare -npu /home/xxx/npu/accuracy_checking_de
216
216
 
217
217
  | 参数名称 | 说明 | 是否必选 |
218
218
  |-----------------------| ------------- | -------- |
219
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
219
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
220
220
  | -npu 或 --npu_csv_path | NPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 |
221
221
  | -gpu 或 --gpu_csv_path | GPU 预检结果 `accuracy_checking_details_{timestamp}.csv` 文件路径。默认从当前目录下识别该文件。 | 是 |
222
222
  | -o 或 --out_path | 指定 api_precision_compare.py 执行结果存盘路径,默认为当前目录。 | 否 |
@@ -53,15 +53,15 @@ msprobe -f pytorch compare -i ./compare.json -o ./output -s
53
53
 
54
54
  | 参数名 | 说明 | 是否必选 |
55
55
  |---------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -------- |
56
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
56
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
57
57
  | -i 或 --input_path | 指定[比对文件](#51-比对文件),str 类型。 | 是 |
58
58
  | -o 或 --output_path | 配置比对结果文件存盘目录,str 类型,默认在当前目录创建output目录。文件名称基于时间戳自动生成,格式为:`compare_result_{timestamp}.xlsx`。<br>提示:output目录下与结果件同名文件将被删除覆盖。 | 否 |
59
59
  | -s 或 --stack_mode | 比对结果展示调用栈信息(NPU_Stack_Info)的开关,bool 类型。单卡场景开启时,根据[比对文件](#51-比对文件)的参数说明配置stack_path;多卡场景开启时,自动识别npu_dump目录下stack.json文件,如存在生成详细调用栈信息,否则不生成,此参数不生效。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 |
60
60
  | -c 或 --compare_only | 仅比对开关,bool 类型。该参数默认未配置,会启用自动精度分析,工具自动针对比对结果进行分析,识别到第一个精度可能不达标节点(在比对结果文件中的 Accuracy Reached or Not 列显示为 No),并给出问题可能产生的原因(打屏展示并生成 `advisor_{timestamp}.txt` 文件)。通过配置该参数取消自动精度分析,仅输出比对结果表格。 | 否 |
61
61
  | -f 或 --fuzzy_match | 模糊匹配,bool 类型。开启后,对于网络中同一层级且命名仅调用次数不同的 API,可匹配并进行比对。通过直接配置该参数开启,默认未配置,表示关闭。 | 否 |
62
62
  | -hl 或 --highlight | 高亮颜色标记。开启后,比对结果件中通过红色或黄色标记精度可疑API或模块。通过直接配置该参数开启,默认未配置,表示关闭。 开启高亮颜色标记后,比对性能降低,如果比对结果行数超出excel单页限制,程序强制关闭高亮颜色标记。 | 否 |
63
- | -dm或--data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#52-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#51-比对文件)的单卡场景示例。 | 否 |
64
- | -da或--diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 |
63
+ | -dm 或 --data_mapping | 自定义映射关系比对。需要指定自定义映射文件*.yaml。自定义映射文件的格式请参见[自定义映射文件](#52-自定义映射文件)。仅[API和模块无法自动匹配场景](#213-api和模块无法自动匹配场景)需要配置。仅支持逐卡比对,即使用[比对文件](#51-比对文件)的单卡场景示例。 | 否 |
64
+ | -da 或 --diff_analyze | 自动识别网络中首差异节点,支持md5、统计量等dump数据。支持单卡/多卡场景。 | 否 |
65
65
 
66
66
  #### 2.1.2 整网比对场景
67
67
 
@@ -396,7 +396,7 @@ msprobe -f pytorch merge_result -i ./input_dir -o ./output_dir -config ./config.
396
396
 
397
397
  | 参数名 | 说明 | 是否必选 |
398
398
  | --------------------- |-------------------------------------------------------------------------------------------------------------------| -------- |
399
- | -f 或 --framework | 指定训练框架。pytorch。 | 是 |
399
+ | -f 或 --framework | 指定训练框架,当前场景配置为pytorch。 | 是 |
400
400
  | -i 或 --input_dir | 多卡比对结果存盘目录,即使用compare比对的结果输出目录,str类型。所有比对结果应全部为真实数据比对结果或统计数据比对结果,否则可能导致汇总数据不完整。 | 是 |
401
401
  | -o 或 --output_dir | 数据提取汇总结果存盘目录,str类型。文件名称基于时间戳自动生成,格式为:`multi_ranks_compare_merge_{timestamp}.xlsx`。<br>提示:output目录下与结果件同名文件将被删除覆盖。 | 是 |
402
402
  | -config或--config-path | 指定需要汇总数据的API和比对指标的yaml文件路径,str类型。<br>yaml文件详细介绍见下文“**yaml文件说明**”。 | 是 |
@@ -527,7 +527,7 @@ input_args、input_kwargs和output使用统一的命名规则,当值是list类
527
527
  "input_args": [
528
528
  {
529
529
  "type": "torch.Tensor",
530
- "dytpe": "torch_float32",
530
+ "dtype": "torch_float32",
531
531
  "shape": [
532
532
  1,
533
533
  64,
@@ -604,7 +604,7 @@ output是list,长度为1,第0项后面是Tensor,命名结束;按照顺
604
604
  ```
605
605
  Functional.max_pool2d.0.forward.output.0
606
606
  ```
607
- 综上,生成的的op_name为
607
+ 综上,生成的op_name为
608
608
  ```
609
609
  Functional.max_pool2d.0.forward.input.0
610
610
  Functional.max_pool2d.0.forward.input.1
@@ -8,6 +8,8 @@
8
8
 
9
9
  依赖:CANN 包中的 msaccucmp 工具,需要安装 Ascend-CANN-toolkit,详见《[CANN 软件安装指南](https://www.hiascend.com/document/detail/zh/canncommercial/700/envdeployment/instg/instg_0001.html)》。
10
10
 
11
+ **安全注意事项**: 工具使用时传入的msaccucmp.py文件路径参数请务必确保为CANN中自带的msaccucmp.py原始路径且文件内容未被篡改,工具会在解析数据时直接执行该文件,用户需自行保证该文件内容安全性。
12
+
11
13
  ## 2 数据解析操作指导
12
14
 
13
15
  ### 2.1 进入 parse 交互式界面
@@ -24,7 +24,7 @@
24
24
  | [采集module堆栈信息](#采集module堆栈信息) | 采集监控的第一个 step 的 module 对应的堆栈信息辅助问题定位 | PyTorch、MindSpore |
25
25
  | [指定监控对象](#指定监控对象) | 指定监控的nn.Module(nn.Cell)及对应的输入输出 | PyTorch、MindSpore |
26
26
  | [打印模型结构](#打印模型结构) | 打印模型结构 | PyTorch |
27
- | [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch |
27
+ | [l2可解释特征监控](#l2可解释特征监控) | 开启模型状态的高阶监控 | PyTorch、MindSpore |
28
28
  | [输出格式和统计量](#输出格式和统计量) | format PyTorch支持`csv`、`tensorboard`和`api`,MindSpore仅支持`csv`,`ops`、`ndigits`均支持 | PyTorch、MindSpore |
29
29
  | [mbs粒度梯度监控](#mbs粒度梯度监控) | 开启梯度监控时,采集聚合前梯度时支持`micro_batch_size`粒度 | PyTorch、MindSpore |
30
30
  | [异常告警](#异常告警) | 监控对象指标异常时自动告警,支持异常数据落盘 | PyTorch、MindSpore |
@@ -37,9 +37,9 @@
37
37
  推荐使用方式:权重梯度的监控性能损耗小(20B dense模型全量权重梯度监控,时间增加<1%,内存增加<1%),可以长期开启。激活值监控性能损耗大,在必要时开启或者仅监控部分。
38
38
 
39
39
  ### 工具使能
40
- 在实际训练代码中找到模型、优化器定义的位置,使能monitor工具,通过配置文件(json)控制工具行为。如下分别为Pytorch场景和MindSpore场景下的使能方式。
40
+ 在实际训练代码中找到模型、优化器定义的位置,使能monitor工具,通过配置文件(json)控制工具行为。如下分别为PyTorch场景和MindSpore场景下的使能方式。
41
41
 
42
- - Pytorch使能方式:
42
+ - PyTorch使能方式:
43
43
  ```python
44
44
  # Megatron-LM(core_r0.6.0) training.py
45
45
  model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
@@ -112,7 +112,7 @@ monitor.set_monitor(
112
112
 
113
113
  请注意以下两点:
114
114
  - Mindspore功能在1.2.2版本后支持, <1.2.2版本不支持
115
- - 上述接口使用方式为1.2.2后更新的最新接口使用方式, <1.2.2版本的Pytorch旧接口使用方式为:
115
+ - 上述接口使用方式为1.2.2后更新的最新接口使用方式, <1.2.2版本的PyTorch旧接口使用方式为:
116
116
  ```Python
117
117
  from msprobe.pytorch import TrainerMon
118
118
  monitor = TrainerMon(
@@ -2,7 +2,7 @@
2
2
 
3
3
  分级可视化工具将msprobe工具dump的精度数据进行解析,还原模型图结构,实现模型各个层级的精度数据比对,方便用户理解模型结构、分析精度问题。
4
4
 
5
- 工具支持PyTorch版本:2.1/2.2
5
+ 工具支持PyTorch版本:2.1/2.2/2.5/2.6/2.7
6
6
 
7
7
  ## 工具特性
8
8
 
@@ -7,7 +7,6 @@
7
7
  | [数据采集<br>(dump)](./05.data_dump_PyTorch.md) | 采集模型训练过程中的API或Module层级的前反向输入输出数据,包括层次关系、统计值信息、真实数据和调用栈等。 | 1、将模型中训练的API或Module的前反向输入输出数据保存下来分析<br> 2、模型出现溢出时,可用于查看哪些API或Module出现了溢出 | 1、API级数据采集仅支持白名单列表上的API<br>2、工具会做一些同步操作,引入工具可能会导致一些同步问题消失<br>3、当前对inplace操作API或Module的支持度有限<br>4、暂不支持参数及参数梯度的采集 |
8
8
  | [离线预检<br>(api_accuracy_checker)](./07.accuracy_checker_PyTorch.md) | 为网络中每个API创建用例,检验其精度,并根据不同比对算法综合判定API在NPU上的精度是否达标,快速找出精度差异API。 | 1、对模型中所有的API做精度初步排查<br>2、精度排查不受模型累计误差影响 | 1、依赖GPU环境<br>2、不支持通信算子<br>3、仅支持部分融合算子 |
9
9
  | [整网比对<br>(compare)](./10.accuracy_compare_PyTorch.md) | 计算模型整网NPU和标杆设备的精度误差指标,标记精度异常API或Module,助力快速定位精度问题根因。 | 1、整网比对定位精度可疑算子 | 1、由于使用整网dump数据,定位的可疑算子受累计误差影响<br>2、当模型规模较大时,比对所需时间较长 |
10
- | [在线预检<br>(online_api_accuracy_checker)](./08.accuracy_checker_online_PyTorch.md) | 通过TCP通信或共享存储空间的方式,进行在线精度预检,解决离线预检大数据量落盘、传输困难痛点。 | 1、使用离线预检,数据量较大落盘困难或传输耗时长时,可通过在线预检进行精度排查 | 1、依赖GPU环境,NPU和GPU能够通信<br>2、重计算模式下,不支持反向aten算子预检 |
11
10
  | [溢出检查<br>(overflow_checker)](./12.overflow_check_PyTorch.md) | 检测模型计算过程的输入输出,并在溢出时落盘数据,助力用户快速定位溢出位置。 | 1、当模型出现溢出时,用于快速定位最先溢出的API或Module<br>2、相比数据采集,性能更优,磁盘压力更小 | 1、局限性同数据采集 |
12
11
  | [数据解析<br>(parse_tool)](./14.data_parse_PyTorch.md) | 交互式界面处理解析kernel层级dump数据,便于查看分析。 | 1、比对kernel层级dump数据的一致性 | 1、仅限于NPU |
13
12
  | [无标杆比对<br>(free_benchmark)](./15.free_benchmarking_PyTorch.md) | 不依赖标杆数据,通过对算子输入增加微小扰动,计算扰动后输出与原始输出的相对误差,识别有精度风险算子。 | 1、无标杆数据场景下的算子精度排查<br>2、对个别算子进行升精度、“to cpu”等操作,以验证其对模型loss的影响 | 1、由于需要拷贝输入进行二次执行,所以在遇到大张量的输入时容易发生显存OOM的问题, 特别是反向比对过程。建议结合白名单使用<br>2、比对会延长训练时间,整网比对可能会造成严重的耗时膨胀,建议结合白名单使用 |
@@ -19,11 +19,11 @@ Megatron、MindSpeed的ckpt加载依赖megatron,请确保megatron在python环
19
19
  msprobe --framework pytorch config_check --compare path1 path2 -o output_path.json
20
20
  ```
21
21
 
22
- | 参数名 | 解释 | 是否必选 |
23
- |--------|-------|--------|
24
- | -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 |
25
- | -c 或 --compare | 2个ckpt的路径 | 是 |
26
- | -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
22
+ | 参数名 | 解释 | 是否必选 |
23
+ |--------|----------------------------------------------------------------------------------------------|--------|
24
+ | -f 或 --framework | 深度学习框架,str类型。支持参数:pytorch,mindspore,注意:msadaptor场景传入mindspore。 | 是 |
25
+ | -c 或 --compare | 2个ckpt的路径。在ckpt传给工具加载前,用户需要确保ckpt是安全可信的,若ckpt来源官方有提供SHA256等校验值,用户必须要进行校验,以确保ckpt没有被篡改。 | 是 |
26
+ | -o 或 --output | 比对结果输出路径,默认为 ./ckpt_similarity.json。输出路径存在时将报错终止。 | 否 |
27
27
 
28
28
  Megatron-LM 和 MindSpeed 的 ckpt 目录结构如下:
29
29
 
@@ -163,15 +163,11 @@ class GradContext:
163
163
  def __init__(self) -> None:
164
164
  self.pre = {}
165
165
  self.post = {}
166
- self.acc_metric = {}
167
- self.acc = {}
168
166
  self.actv = {}
169
167
 
170
168
  def reset(self):
171
169
  self.pre.clear()
172
170
  self.post.clear()
173
- self.acc_metric.clear()
174
- self.acc.clear()
175
171
  self.actv.clear()
176
172
 
177
173
 
@@ -312,7 +308,6 @@ class TrainerMon:
312
308
  self.recording_l2_features = self.config.get('recording_l2_features', False)
313
309
  self.sa_order = self.config.get('sa_order', "s,b,h,d")
314
310
 
315
-
316
311
  if not self.cc_distribution.get('enable', False):
317
312
  self.cc_log_only = False
318
313
  else:
@@ -403,7 +398,7 @@ class TrainerMon:
403
398
  if self.monitoring:
404
399
  module_rank_valid = self.is_target_rank()
405
400
  step_condition = (context.step >= self.start_step and (
406
- context.step - self.start_step) % self.step_interval == 0)
401
+ context.step - self.start_step) % self.step_interval == 0)
407
402
  if module_rank_valid and step_condition:
408
403
  self.has_collect_times += 1
409
404
 
@@ -447,6 +442,7 @@ class TrainerMon:
447
442
  hook(optimizer, args, kwargs)
448
443
  step_final_hook(optimizer, args, kwargs)
449
444
  return out
445
+
450
446
  return wrapper
451
447
 
452
448
  if self.is_mindtorch:
@@ -541,7 +537,7 @@ class TrainerMon:
541
537
  if self.mv_distribution or self.ur_distribution or self.mg_direction:
542
538
  if self.is_mindtorch:
543
539
  context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, \
544
- context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
540
+ context.param_adam_ratio = self.optimizer_mon.fetch_mv(self, self.param2name)
545
541
  else:
546
542
  context.param_exp_avg, context.param_exp_avg_sq = self.get_mv_for_ms(optimizer)
547
543
 
@@ -564,7 +560,6 @@ class TrainerMon:
564
560
  context = self.optimizer_context[optimizer]
565
561
  self.generate_param_metrics(context, MonitorConst.POST_PARAM)
566
562
 
567
-
568
563
  if self.optimizer_hooked or not self.is_target_rank():
569
564
  return
570
565
 
@@ -577,7 +572,6 @@ class TrainerMon:
577
572
  if not self.wg_distribution:
578
573
  return
579
574
 
580
- get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
581
575
  get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
582
576
 
583
577
  def generate_param_map(self, tag, param_tensor):
@@ -671,7 +665,7 @@ class TrainerMon:
671
665
  if not self.wg_distribution:
672
666
  return
673
667
 
674
- self.summary_writer.write_metrics(self.ops, self.grad_context.acc_metric, step, 'grad_unreduced',
668
+ self.summary_writer.write_metrics(self.ops, self.grad_context.pre, step, 'grad_unreduced',
675
669
  use_micro_step=self.monitor_mbs_grad)
676
670
  self.summary_writer.write_metrics(self.ops, self.grad_context.post, step, 'grad_reduced')
677
671
 
@@ -810,7 +804,7 @@ class TrainerMon:
810
804
  f'{context.module_name}.{Const.INPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
811
805
  MonitorConst.ACTV, module_input))
812
806
  module_output = [tensor for tensor in module_output if isinstance(tensor, Tensor)] \
813
- if isinstance(module_output, tuple) else module_output
807
+ if isinstance(module_output, tuple) else module_output
814
808
  tbtag_tensor_map.update(
815
809
  self.build_tbtag_tensor_map(
816
810
  f'{context.module_name}.{Const.OUTPUT}', f'{MonitorConst.NAME_SEP}{context.micro_step}',
@@ -868,11 +862,13 @@ class TrainerMon:
868
862
  if version.parse(mindspore.__version__) >= version.parse('2.6.0'):
869
863
  def wrapper(module, args, kwargs, module_output):
870
864
  return fwd_hook_fun(module, args, kwargs, module_output, name)
865
+
871
866
  return module.register_forward_hook(wrapper, with_kwargs=True)
872
867
 
873
868
  else:
874
869
  def wrapper(module, args, module_output):
875
870
  return fwd_hook_fun(module, args, None, module_output, name)
871
+
876
872
  return module.register_forward_hook(wrapper)
877
873
 
878
874
  def extract_attention_feature_hook(module, args, kwargs, module_output, name):
@@ -880,7 +876,7 @@ class TrainerMon:
880
876
  if kwargs:
881
877
  kwargs_tensors = [tensor for tensor in kwargs.values() if isinstance(tensor, Tensor)]
882
878
  module_input.extend(kwargs_tensors)
883
-
879
+
884
880
  if module not in self.feature_hook_context_by_module:
885
881
  self.feature_hook_context_by_module[module] = FeatureHookContext(name)
886
882
  context: FeatureHookContext = self.feature_hook_context_by_module[module]
@@ -890,7 +886,7 @@ class TrainerMon:
890
886
  logger.warning(
891
887
  "Calculate attention feature failed, the length of module_input in attention hook's module should "
892
888
  "be greater than or equal to 2.")
893
-
889
+
894
890
  q_h = module_input[0]
895
891
  k_h = module_input[1]
896
892
  qkt = cal_qkt(q_h, k_h, order=self.sa_order)
@@ -985,25 +981,26 @@ class TrainerMon:
985
981
  self._hook_weights()
986
982
 
987
983
  def _hook_weights(self):
988
- context = self.grad_context
989
984
 
990
985
  @_no_grad()
991
- def param_hook(grad, context_dict, param, name):
986
+ def param_hook(grad, param, name):
992
987
  key = name
993
988
  if self.monitor_mbs_grad:
994
989
  key += f'{MonitorConst.NAME_SEP}{param.micro_step}'
995
990
  key = get_summary_writer_tag_name(key, 'acc_grad', self.rank)
996
991
  self.register_param_call_id("param_hook", key)
997
992
  param.micro_step += 1
998
-
993
+ grad_dict = {}
999
994
  if self.monitor_mbs_grad or (param.micro_step == self.micro_batch_number):
1000
- context_dict[key] = grad
995
+ grad_dict[key] = grad
996
+ get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
997
+
1001
998
  if param.micro_step == self.micro_batch_number:
1002
999
  param.micro_step = 0
1003
1000
 
1004
- def param_hook_wrapper(param_hook, context_dict, param, name):
1001
+ def param_hook_wrapper(param_hook, param, name):
1005
1002
  def wrapper(grad):
1006
- return param_hook(grad, context_dict, param, name)
1003
+ return param_hook(grad, param, name)
1007
1004
 
1008
1005
  return wrapper
1009
1006
 
@@ -1011,7 +1008,7 @@ class TrainerMon:
1011
1008
  for param, name in self.param2name.items():
1012
1009
  setattr(param, 'micro_step', 0)
1013
1010
  handle = param.register_hook(
1014
- param_hook_wrapper(param_hook, context_dict=context.acc, param=param, name=name))
1011
+ param_hook_wrapper(param_hook, param=param, name=name))
1015
1012
  self.handles['wgrads'].append(handle)
1016
1013
  self.weight_hooked = True
1017
1014
 
@@ -24,8 +24,7 @@ from msprobe.pytorch.pt_config import RunUTConfig
24
24
 
25
25
  RunUtConfig = namedtuple('RunUtConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path',
26
26
  'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list',
27
- 'black_list', 'error_data_path', 'online_config'])
28
- OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path'])
27
+ 'black_list', 'error_data_path'])
29
28
 
30
29
 
31
30
  class Config:
@@ -46,13 +45,7 @@ class Config:
46
45
  'white_list': list,
47
46
  'black_list': list,
48
47
  'error_data_path': str,
49
- 'precision': int,
50
- 'is_online': bool,
51
- 'nfs_path': str,
52
- 'host': str,
53
- 'port': int,
54
- 'rank_list': list,
55
- 'tls_path': str
48
+ 'precision': int
56
49
  }
57
50
  if key not in validators:
58
51
  raise ValueError(f"{key} must be one of {validators.keys()}")
@@ -68,10 +61,6 @@ class Config:
68
61
  RunUTConfig.check_filter_list_config(key, value)
69
62
  if key == 'error_data_path':
70
63
  RunUTConfig.check_error_data_path_config(value)
71
- if key == 'nfs_path':
72
- RunUTConfig.check_nfs_path_config(value)
73
- if key == 'tls_path':
74
- RunUTConfig.check_tls_path_config(value)
75
64
  return value
76
65
 
77
66
 
@@ -85,12 +74,6 @@ class CheckerConfig:
85
74
  self.white_list = msCheckerConfig.white_list
86
75
  self.black_list = msCheckerConfig.black_list
87
76
  self.error_data_path = msCheckerConfig.error_data_path
88
- self.is_online = msCheckerConfig.is_online
89
- self.nfs_path = msCheckerConfig.nfs_path
90
- self.host = msCheckerConfig.host
91
- self.port = msCheckerConfig.port
92
- self.rank_list = msCheckerConfig.rank_list
93
- self.tls_path = msCheckerConfig.tls_path
94
77
 
95
78
  if task_config:
96
79
  self.load_config(task_config)
@@ -99,22 +82,7 @@ class CheckerConfig:
99
82
  self.white_list = task_config.white_list
100
83
  self.black_list = task_config.black_list
101
84
  self.error_data_path = task_config.error_data_path
102
- self.is_online = task_config.is_online
103
- self.nfs_path = task_config.nfs_path
104
- self.host = task_config.host
105
- self.port = task_config.port
106
- self.rank_list = task_config.rank_list
107
- self.tls_path = task_config.tls_path
108
85
 
109
- def get_online_config(self):
110
- return OnlineConfig(
111
- is_online=self.is_online,
112
- nfs_path=self.nfs_path,
113
- host=self.host,
114
- port=self.port,
115
- rank_list=self.rank_list,
116
- tls_path=self.tls_path
117
- )
118
86
 
119
87
  def get_run_ut_config(self, **config_params):
120
88
  return RunUtConfig(
@@ -127,6 +95,5 @@ class CheckerConfig:
127
95
  real_data_path=config_params.get('real_data_path'),
128
96
  white_list=self.white_list.copy() if self.white_list else [],
129
97
  black_list=self.black_list.copy() if self.black_list else [],
130
- error_data_path=config_params.get('error_data_path'),
131
- online_config=self.get_online_config()
98
+ error_data_path=config_params.get('error_data_path')
132
99
  )
@@ -117,30 +117,6 @@ def api_precision_compare(config):
117
117
  change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
118
118
 
119
119
 
120
- def online_api_precision_compare(online_config):
121
- rank = online_config.rank
122
- result_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.result_csv_path).replace(
123
- "_rank*.csv", f"_rank{rank}.csv")
124
- details_csv_path = os.path.join(Const.DEFAULT_PATH, online_config.details_csv_path).replace(
125
- "_rank*.csv", f"_rank{rank}.csv")
126
- detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()]
127
- result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()]
128
- if not os.path.exists(result_csv_path):
129
- write_csv(result_csv_title, result_csv_path)
130
- if not os.path.exists(details_csv_path):
131
- write_csv(detail_csv_title, details_csv_path)
132
- config = CompareConfig("", "", result_csv_path, details_csv_path)
133
- try:
134
- npu_data, gpu_data = online_config.npu_data, online_config.gpu_data
135
- check_csv_columns(npu_data.columns, "npu_csv")
136
- check_csv_columns(gpu_data.columns, "gpu_csv")
137
- analyse_csv(npu_data, gpu_data, config)
138
- except Exception as err:
139
- logger.error(f"Online api precision compare Error: {str(err)}")
140
- change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
141
- change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY)
142
-
143
-
144
120
  def analyse_csv(npu_data, gpu_data, config):
145
121
  forward_status, backward_status = [], []
146
122
  last_api_name, last_api_dtype, last_api_full_name = None, None, None
@@ -66,13 +66,6 @@ class Comparator:
66
66
  self.save_path_list = [result_csv_path]
67
67
  self.detail_save_path_list = [details_csv_path]
68
68
 
69
- if config and config.online_config.is_online:
70
- self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv")
71
- self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv")
72
- self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list]
73
- self.detail_save_path_list = \
74
- [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list]
75
-
76
69
  self.registry = self._register_compare_func()
77
70
 
78
71
  if not is_continue_run_ut:
@@ -245,9 +238,8 @@ class Comparator:
245
238
  self.write_detail_csv(args)
246
239
 
247
240
 
248
- def compare_output(self, full_api_name, data_info, is_online=False):
241
+ def compare_output(self, full_api_name, data_info):
249
242
  """Get compare result and write to result and detail csv.
250
- is_online: bool, default False. True: called by online api precision compare, only compare without write to csv.
251
243
  """
252
244
  _, api_name = extract_basic_api_segments(full_api_name)
253
245
  if not api_name:
@@ -280,9 +272,7 @@ class Comparator:
280
272
  fwd_compare_alg_results,
281
273
  bwd_compare_alg_results,
282
274
  data_info.rank)
283
- if is_online:
284
- # get run_ut compare detail
285
- return self._get_run_ut_detail(result_info)
275
+
286
276
  self.record_results(result_info)
287
277
  return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS \
288
278
  or bwd_success_status == CompareConst.SPACE
@@ -2,9 +2,4 @@ white_list: []
2
2
  black_list: []
3
3
  error_data_path: './'
4
4
  precision: 14
5
- is_online: False
6
- nfs_path: ""
7
- host: ""
8
- port: -1
9
- rank_list: [0]
10
- tls_path: "./"
5
+