mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
  3. msprobe/README.md +2 -2
  4. msprobe/core/common/const.py +34 -9
  5. msprobe/core/common/inplace_ops.yaml +1 -0
  6. msprobe/core/common/utils.py +14 -0
  7. msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
  8. msprobe/core/compare/merge_result/merge_result.py +8 -7
  9. msprobe/core/compare/merge_result/utils.py +81 -0
  10. msprobe/core/compare/utils.py +10 -0
  11. msprobe/core/data_dump/data_collector.py +58 -13
  12. msprobe/core/data_dump/data_processor/base.py +92 -8
  13. msprobe/core/data_dump/data_processor/factory.py +3 -0
  14. msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
  15. msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
  16. msprobe/core/data_dump/json_writer.py +26 -8
  17. msprobe/docs/01.installation.md +25 -0
  18. msprobe/docs/02.config_introduction.md +14 -12
  19. msprobe/docs/03.config_examples.md +24 -0
  20. msprobe/docs/05.data_dump_PyTorch.md +34 -15
  21. msprobe/docs/06.data_dump_MindSpore.md +45 -22
  22. msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
  23. msprobe/docs/19.monitor.md +257 -260
  24. msprobe/docs/21.visualization_PyTorch.md +10 -0
  25. msprobe/docs/22.visualization_MindSpore.md +11 -0
  26. msprobe/docs/27.dump_json_instruction.md +24 -20
  27. msprobe/docs/28.debugger_save_instruction.md +94 -0
  28. msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
  29. msprobe/docs/img/monitor/step_count_per_record.png +0 -0
  30. msprobe/mindspore/__init__.py +1 -0
  31. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
  32. msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
  33. msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
  34. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
  35. msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
  36. msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
  37. msprobe/mindspore/common/utils.py +20 -2
  38. msprobe/mindspore/debugger/debugger_config.py +25 -2
  39. msprobe/mindspore/debugger/precision_debugger.py +25 -6
  40. msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
  41. msprobe/mindspore/dump/jit_dump.py +7 -6
  42. msprobe/mindspore/monitor/anomaly_detect.py +404 -0
  43. msprobe/mindspore/monitor/distributed/__init__.py +0 -0
  44. msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
  45. msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
  46. msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
  47. msprobe/mindspore/monitor/features.py +63 -0
  48. msprobe/mindspore/monitor/module_hook.py +821 -0
  49. msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
  50. msprobe/mindspore/monitor/utils.py +267 -0
  51. msprobe/mindspore/ms_config.py +8 -2
  52. msprobe/mindspore/service.py +95 -21
  53. msprobe/pytorch/__init__.py +0 -1
  54. msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
  55. msprobe/pytorch/bench_functions/apply_adam.py +215 -0
  56. msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
  57. msprobe/pytorch/bench_functions/mish.py +21 -0
  58. msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
  59. msprobe/pytorch/bench_functions/sort_v2.py +21 -0
  60. msprobe/pytorch/common/utils.py +71 -0
  61. msprobe/pytorch/debugger/debugger_config.py +19 -9
  62. msprobe/pytorch/debugger/precision_debugger.py +14 -0
  63. msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
  64. msprobe/pytorch/function_factory.py +7 -1
  65. msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
  66. msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
  67. msprobe/pytorch/monitor/anomaly_detect.py +14 -29
  68. msprobe/pytorch/monitor/csv2tb.py +10 -12
  69. msprobe/pytorch/monitor/module_hook.py +123 -104
  70. msprobe/pytorch/monitor/module_metric.py +6 -6
  71. msprobe/pytorch/monitor/optimizer_collect.py +45 -63
  72. msprobe/pytorch/monitor/utils.py +8 -43
  73. msprobe/pytorch/pt_config.py +19 -22
  74. msprobe/pytorch/service.py +103 -24
  75. msprobe/visualization/builder/graph_builder.py +31 -5
  76. msprobe/visualization/builder/msprobe_adapter.py +7 -5
  77. msprobe/visualization/graph/base_node.py +3 -2
  78. msprobe/visualization/graph/distributed_analyzer.py +80 -3
  79. msprobe/visualization/graph/node_op.py +4 -2
  80. msprobe/visualization/graph_service.py +3 -4
  81. msprobe/visualization/utils.py +10 -2
  82. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
  83. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
  84. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
  85. {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
@@ -302,6 +302,16 @@ msprobe -f pytorch graph -i ./compare.json -o ./output
302
302
  ├── compare_stepn_rankn_{timestamp}.vis
303
303
  ```
304
304
 
305
+ #### 3.2.4 仅模型结构比对
306
+
307
+ 适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。
308
+
309
+ 使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。
310
+
311
+ dump配置请参考[dump配置示例](./03.config_examples.md#16-task-配置为-structure)
312
+
313
+ 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。
314
+
305
315
  ## 4.启动tensorboard
306
316
 
307
317
  ### 4.1 可直连的服务器
@@ -303,6 +303,17 @@ msprobe -f mindspore graph -i ./compare.json -o ./output
303
303
  ├── compare_stepn_rankn_{timestamp}.vis
304
304
  ```
305
305
 
306
+ #### 3.2.4 仅模型结构比对
307
+
308
+ 适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。
309
+
310
+ 使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。
311
+
312
+ dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为-structure)
313
+
314
+ 得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。
315
+
316
+
306
317
  ## 4.启动tensorboard
307
318
 
308
319
  ### 4.1 可直连的服务器
@@ -1,16 +1,18 @@
1
1
  # dump.json文件说明及示例
2
2
 
3
- ## 1. dump.json文件介绍(Pytorch
3
+ ## 1. dump.json文件示例(PyTorch
4
4
 
5
5
  ### 1.1 L0级别
6
- L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以Pytorch的Conv2d模块为例,网络中模块调用代码为:
7
- `output = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)(input)`
6
+ L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为:
7
+ `output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)`
8
8
 
9
- dump.json文件中包含以下字段:
9
+ dump.json文件中包含以下数据名称:
10
10
 
11
- 1. `Module.conv2.Conv2d.forward.0`为模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
12
- 2. `Module.conv2.Conv2d.parameters_grad`为模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
13
- 3. `Module.conv2.Conv2d.backward.0`为模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
11
+ - `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
12
+ - `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
13
+ - `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
14
+
15
+ **说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv1.Conv2d.forward.0`。
14
16
 
15
17
  ```json
16
18
  {
@@ -167,12 +169,12 @@ dump.json文件中包含以下字段:
167
169
  ```
168
170
 
169
171
  ### 1.2 L1级别
170
- L1级别的dump.json文件包括API的前反向的输入输出。以Pytorch的relu函数为例,网络中API调用代码为:
171
- `output = torch.nn.functional.relu(input)`
172
+ L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为:
173
+ `output = torch.nn.functional.relu(input)`
172
174
 
173
- dump.json文件中包含以下字段:
174
- 1. `Functional.relu.0.forward`为API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
175
- 2. `Functional.relu.0.backward`为API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
175
+ dump.json文件中包含以下数据名称:
176
+ - `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
177
+ - `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
176
178
 
177
179
  ```json
178
180
  {
@@ -272,12 +274,14 @@ mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式
272
274
 
273
275
  L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。
274
276
  以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为:
275
- `output = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)(input)`
277
+ `output = self.conv2(input) # self.conv2 = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)`
278
+
279
+ dump.json文件中包含以下数据名称:
280
+ - `Cell.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
281
+ - `Cell.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
282
+ - `Cell.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
276
283
 
277
- dump.json文件中包含以下字段:
278
- 1. `Cell.conv2.Conv2d.forward.0`为模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
279
- 2. `Cell.conv2.Conv2d.parameters_grad`为模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
280
- 3. `Cell.conv2.Conv2d.backward.0`为模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
284
+ **说明**:当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Cell}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Cell.0.conv2.Conv2d.forward.0`。
281
285
 
282
286
  ```json
283
287
  {
@@ -429,9 +433,9 @@ dump.json文件中包含以下字段:
429
433
  L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为:
430
434
  `output = mindspore.ops.relu(input)`
431
435
 
432
- dump.json文件中包含以下字段:
433
- 1. `Functional.relu.0.forward`为API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
434
- 2. `Functional.relu.0.backward`为API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
436
+ dump.json文件中包含以下数据名称:
437
+ - `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
438
+ - `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
435
439
 
436
440
  ```json
437
441
  {
@@ -0,0 +1,94 @@
1
+ # 单点保存工具 README
2
+
3
+ ## 简介
4
+ L0, L1, mix dump存在盲区,网络中的非api/module的输入输出不会被批量dump下来。单点保存提供类似np.save和print的功能和使用体验,可以保存指定的变量。同时针对大模型场景进行了增强,具备以下特性:
5
+ - 可保存变量的反向梯度结果。
6
+ - 能直接保存嵌套结构数据(如 list、dict),无需手动遍历。
7
+ - 自动分 rank 保存。
8
+ - 多次调用时会自动计数。
9
+ - 可配置保存统计值或者张量。
10
+
11
+ ## 支持场景
12
+ 仅支持 PyTorch 与 MindSpore 的动态图场景。
13
+
14
+ ## 使能方式
15
+
16
+ ### 配置文件说明
17
+
18
+ 通用配置:
19
+
20
+ | 参数 | 解释 | 是否必选 |
21
+ | -------- |-------------------------------------------| -------- |
22
+ | task | dump 的任务类型,str 类型。 单点保存场景仅支持传入"statistics", "tensor"。 | 是 |
23
+ | level | dump 级别,str 类型,根据不同级别采集不同数据。单点保存场景传入"debug"。 | 是 |
24
+ | dump_path | 设置 dump 数据目录路径,str 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 是 |
25
+ | rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 否 |
26
+
27
+ "statistics" 任务子配置项:
28
+ | 参数 | 解释 | 是否必选 |
29
+ | -------- |-------------------------------------------| -------- |
30
+ | summary_mode | 控制 dump 文件输出的模式,str 类型。支持传入"statistics", "md5"。 细节详见[statistics任务子配置项说明](./02.config_introduction.md#12-task-配置为-statistics) | 否 |
31
+
32
+ "tensor" 任务无子配置项。
33
+
34
+ ### 接口调用说明
35
+
36
+ 调用PrecisionDebugger.save,传入需要保存的变量,指定变量名称以及是否需要保存反向数据。接口入参说明详见[pytorch单点保存接口](./05.data_dump_PyTorch.md#19-save),[mindspore单点保存接口](./06.data_dump_MindSpore.md#615-save)
37
+
38
+ ### 实例(以pytorch场景为例)
39
+
40
+ 配置文件
41
+ ```json
42
+ {
43
+ "task": "statistics",
44
+ "dump_path": "./dump_path",
45
+ "rank": [],
46
+ "level": "debug",
47
+ "statistics": {
48
+ "summary_mode": "statistics"
49
+ }
50
+ }
51
+ ```
52
+
53
+ 初始化
54
+ ```python
55
+ # 训练启动py脚本
56
+ from mindspore.pytorch import PrecisionDebugger
57
+ debugger = PrecisionDebugger("./config.json")
58
+ for data, label in data_loader:
59
+ # 执行模型训练
60
+ train(data, label)
61
+
62
+ ```
63
+
64
+ 初始化(无配置文件)
65
+ ```python
66
+ # 训练启动py脚本
67
+ from mindspore.pytorch import PrecisionDebugger
68
+ debugger = PrecisionDebugger(dump_path="dump_path", level="debug")
69
+ for data, label in data_loader:
70
+ # 执行模型训练
71
+ train(data, label)
72
+
73
+ ```
74
+
75
+ 调用保存接口
76
+ ```python
77
+ # 训练过程中被调用py文件
78
+ from mindspore.pytorch import PrecisionDebugger
79
+ dict_variable = {"key1": "value1", "key2": [1, 2]}
80
+ PrecisionDebugger.save(dict_variable, "dict_variable", save_backward=False)
81
+
82
+ ```
83
+
84
+ ## 输出结果
85
+ * **"task" 配置为 "statistics" 场景** :在 dump 目录下会生成包含变量统计值信息的 `debug.json` 文件。
86
+ * **"task" 配置为 "tensor" 场景** :除了在 dump 目录下生成包含变量统计值信息的 `debug.json` 文件外,还会在 dump 子目录 `dump_tensor_data` 中保存张量二进制文件,文件名称格式为 `{variable_name}{grad_flag}.{count}.tensor.{indexes}.{file_suffix}`。
87
+
88
+ - variable_name: 传入save接口的变量名称。
89
+ - grad_flag: 反向数据标识,反向数据为"_grad",正向数据为""。
90
+ - count: 调用计数,多次以相同变量名称调用时的计数。
91
+ - indexes: 索引,在保存嵌套结构数据时的索引。例如:嵌套结构为`{"key1": "value1", "key2": ["value2", "value3"]}`,"value2"的索引为"key2.0"
92
+ - file_suffix:文件后缀,pytorch场景为"pt",mindspore场景为"npy"
93
+
94
+
@@ -0,0 +1,69 @@
1
+ # MindSpore 场景的 kernel dump 说明
2
+
3
+ 当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。
4
+
5
+ 本文主要介绍 kernel dump 的配置示例和采集结果介绍, msprobe 数据采集功能的详细使用参考 《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》。
6
+
7
+ ## 1 kernel dump 配置示例
8
+
9
+ 使用 kernel dump 时,list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。
10
+ API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。
11
+
12
+ ```json
13
+ {
14
+ "task": "tensor",
15
+ "dump_path": "/home/data_dump",
16
+ "level": "L2",
17
+ "rank": [],
18
+ "step": [],
19
+ "tensor": {
20
+ "scope": [],
21
+ "list": ["Functional.linear.0.backward"]
22
+ }
23
+ }
24
+ ```
25
+
26
+ ## 2 结果文件介绍
27
+
28
+ ### 2.1 采集结果说明
29
+
30
+ 如果 API kernel 级数据采集成功,会打印以下信息:
31
+
32
+ ```bash
33
+ The kernel data of {api_name} is dumped successfully.
34
+ ```
35
+
36
+ 注意:如果打印该信息后,没有数据生成,参考**常见问题3.1**进行排查。
37
+
38
+ 如果 kernel dump 遇到不支持的 API, 会打印以下信息:
39
+
40
+ ```bash
41
+ The kernel dump does not support the {api_name} API.
42
+ ```
43
+
44
+ 其中 {api_name} 是对应溢出的 API 名称。
45
+
46
+ ### 2.2 输出文件说明
47
+ kernel dump 采集成功后,会在指定的 dump_path 目录下生成如下文件:
48
+
49
+ ```
50
+ ├── /home/data_dump/
51
+ │ ├── step0
52
+ │ │ ├── 20241201103000 # 日期时间格式,表示2024-12-01 10:30:00
53
+ │ │ │ ├── 0 # 表示 device id
54
+ │ │ │ │ ├──{op_type}.{op_name}.{task_id}.{stream_id}.{timestamp} # kernel 层算子数据
55
+ │ │ │ ...
56
+ │ │ ├── kernel_config_{device_id}.json # kernel dump 在接口调用过程中生成的中间文件,一般情况下无需关注
57
+ │ │ ...
58
+ │ ├── step1
59
+ │ ...
60
+ ```
61
+ 成功采集到数据后,可以使用 msprobe 工具提供的《[PyTorch 场景的数据解析](./14.data_parse_PyTorch.md)》功能分析数据。
62
+
63
+ ## 3 常见问题
64
+
65
+ #### 3.1 采集结果文件为空,有可能是什么原因?
66
+
67
+ 1. 首先需要确认工具使用方式、配置文件内容、list 填写的 API 名称格式是否都正确无误。
68
+
69
+ 2. 其次需要确认 API 是否运行在昇腾 NPU 上,如果是运行在其他设备上则不会存在 kernel 级数据。
@@ -25,3 +25,4 @@ except ImportError:
25
25
 
26
26
  from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger
27
27
  from msprobe.mindspore.common.utils import seed_all
28
+ from msprobe.mindspore.monitor.module_hook import TrainerMon
@@ -26,6 +26,7 @@ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
26
26
  from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
27
27
  trim_output_compute_element_list)
28
28
  from msprobe.mindspore.common.log import logger
29
+ from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
29
30
 
30
31
  cur_path = os.path.dirname(os.path.realpath(__file__))
31
32
  yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
@@ -82,9 +83,11 @@ class ApiAccuracyChecker:
82
83
  # get output
83
84
  if global_context.get_is_constructed():
84
85
  # constructed situation, need use constructed input to run mindspore api getting tested_output
85
- tested_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.MS_FRAMEWORK)
86
+ tested_outputs = api_runner(api_input_aggregation, api_name_str,
87
+ forward_or_backward, global_context.get_framework())
86
88
  else:
87
89
  tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
90
+
88
91
  bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
89
92
  tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
90
93
  bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
@@ -153,13 +156,19 @@ class ApiAccuracyChecker:
153
156
  real_api_str = Const.SEP.join(api_name_str_list[1:-2])
154
157
  api_list = load_yaml(yaml_path)
155
158
  supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
156
- if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL):
159
+ if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
160
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
161
+ return True
162
+ if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
163
+ and global_context.get_framework() == Const.MT_FRAMEWORK:
157
164
  return True
158
- if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list:
165
+ if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
166
+ and global_context.get_framework() == Const.MS_FRAMEWORK:
159
167
  return True
160
168
  return False
161
169
 
162
170
  def parse(self, api_info_path):
171
+
163
172
  api_info_dict = load_json(api_info_path)
164
173
 
165
174
  # init global context
@@ -167,14 +176,25 @@ class ApiAccuracyChecker:
167
176
  "task field in api_info.json", accepted_type=str,
168
177
  accepted_value=(MsCompareConst.STATISTICS_TASK,
169
178
  MsCompareConst.TENSOR_TASK))
179
+ try:
180
+ framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
181
+ "framework field in api_info.json", accepted_type=str,
182
+ accepted_value=(Const.MS_FRAMEWORK,
183
+ Const.MT_FRAMEWORK))
184
+ except Exception as e:
185
+ framework = Const.MS_FRAMEWORK
186
+ logger.warning(f"JSON parsing error in framework field: {e}")
187
+
188
+ if framework == Const.MT_FRAMEWORK and not torch_mindtorch_importer.is_valid_pt_mt_env:
189
+ raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment")
190
+
170
191
  is_constructed = task == MsCompareConst.STATISTICS_TASK
171
192
  if not is_constructed:
172
193
  dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
173
- "dump_data_dir field in api_info.json",
174
- accepted_type=str)
194
+ "dump_data_dir field in api_info.json", accepted_type=str)
175
195
  else:
176
196
  dump_data_dir = ""
177
- global_context.init(is_constructed, dump_data_dir)
197
+ global_context.init(is_constructed, dump_data_dir, framework)
178
198
 
179
199
  api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
180
200
  "data field in api_info.json", accepted_type=dict)
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
 
16
16
  import mindspore
17
- import torch
18
17
  from mindspore import ops
19
18
  from msprobe.core.common.const import Const, MsCompareConst
20
19
  from msprobe.core.common.exceptions import ApiAccuracyCheckerException
@@ -24,14 +23,28 @@ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
24
23
  from msprobe.mindspore.common.log import logger
25
24
 
26
25
 
26
+ from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
27
+
28
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
29
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_tensor
30
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_func
31
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_dist
32
+
33
+ if torch_mindtorch_importer.is_valid_pt_mt_env:
34
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
35
+ else:
36
+ import torch
37
+
38
+
39
+
27
40
  class ApiInputAggregation:
28
41
  def __init__(self, inputs, kwargs, gradient_inputs) -> None:
29
- '''
42
+ """
30
43
  Args:
31
44
  inputs: List[ComputeElement]
32
45
  kwargs: dict{str: ComputeElement}
33
46
  gradient_inputs: Union[List[ComputeElement], None]
34
- '''
47
+ """
35
48
  self.inputs = inputs
36
49
  self.kwargs = kwargs
37
50
  self.gradient_inputs = gradient_inputs
@@ -43,16 +56,34 @@ api_parent_module_mapping = {
43
56
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
44
57
  (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
45
58
  (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
46
- (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
59
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
60
+ (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
61
+ (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
62
+ (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
63
+ (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
64
+ (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
65
+ (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
66
+ (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
67
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed
68
+
47
69
  }
48
70
 
71
+
49
72
  api_parent_module_str_mapping = {
50
73
  (MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
51
74
  (MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
52
75
  (MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
53
76
  (MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
54
77
  (MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
55
- (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
78
+ (MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
79
+ (MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
80
+ (MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
81
+ (MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
82
+ (MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
83
+ (MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
84
+ (MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
85
+ (MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
86
+ (MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed"
56
87
  }
57
88
 
58
89
 
@@ -64,7 +95,7 @@ class ApiRunner:
64
95
  api_input_aggregation: ApiInputAggregation
65
96
  api_name_str: str, e.g. "MintFunctional.relu.0"
66
97
  forward_or_backward: str, Union["forward", "backward"]
67
- api_platform: str, Union["mindspore", "torch"]
98
+ api_platform: str, Union["mindspore", "torch", "mindtorch"]
68
99
 
69
100
  Return:
70
101
  outputs: list[ComputeElement]
@@ -72,35 +103,41 @@ class ApiRunner:
72
103
  Description:
73
104
  run mindspore.mint/torch api
74
105
  '''
75
- api_type_str, api_sub_name = self.get_info_from_name(api_name_str)
106
+
107
+ api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
76
108
  api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
77
109
 
78
110
  return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
79
111
 
80
112
  @staticmethod
81
- def get_info_from_name(api_name_str):
82
- '''
113
+ def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
114
+ """
83
115
  Args:
84
116
  api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
85
-
117
+ api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
118
+ It specifies which framework is being used. Default is "mindspore".
86
119
  Return:
87
- api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
120
+ api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
88
121
  api_sub_name: str, e.g. "relu"
89
- '''
122
+ """
90
123
  api_name_list = api_name_str.split(Const.SEP)
91
124
  if len(api_name_list) != 3:
92
125
  err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
93
126
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
94
127
  api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
95
- if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]:
128
+ if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \
129
+ and api_platform == Const.MS_FRAMEWORK:
96
130
  err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
97
131
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
98
132
 
133
+ if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
134
+ err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
135
+ logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
99
136
  return api_type_str, api_sub_name
100
137
 
101
138
  @staticmethod
102
139
  def get_api_instance(api_type_str, api_sub_name, api_platform):
103
- '''
140
+ """
104
141
  Args:
105
142
  api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
106
143
  api_sub_name: str, e.g. "relu"
@@ -113,11 +150,12 @@ class ApiRunner:
113
150
  get mindspore.mint/torch api fucntion
114
151
  mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
115
152
  mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
116
- '''
153
+ """
117
154
 
118
155
  api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
119
156
  api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
120
157
  full_api_name = api_parent_module_str + Const.SEP + api_sub_name
158
+
121
159
  if not hasattr(api_parent_module, api_sub_name):
122
160
  err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
123
161
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
@@ -147,7 +185,7 @@ class ApiRunner:
147
185
  logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
148
186
  gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
149
187
  for compute_element in gradient_inputs)
150
- if api_platform == Const.MS_FRAMEWORK:
188
+ if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
151
189
  if len(gradient_inputs) == 1:
152
190
  gradient_inputs = gradient_inputs[0]
153
191
 
@@ -25,6 +25,7 @@ from msprobe.core.common.file_utils import load_npy
25
25
  from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
26
26
  ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
27
27
  dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
28
+ dtype_str_to_mindtorch_dtype,
28
29
  dtype_str_to_torch_dtype, type_to_api_info_type_str,
29
30
  DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
30
31
  MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
@@ -33,6 +34,15 @@ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_s
33
34
  from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
34
35
  from msprobe.mindspore.common.log import logger
35
36
 
37
+ import msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer as env_module
38
+
39
+
40
+ if env_module.is_valid_pt_mt_env:
41
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
42
+ from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
43
+ else:
44
+ import torch
45
+
36
46
 
37
47
  class MstensorMetaData:
38
48
  def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
@@ -86,6 +96,37 @@ class ComputeElement:
86
96
  torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
87
97
  return torch_tensor
88
98
 
99
+ @staticmethod
100
+ def transfer_to_mindtorch_tensor(ms_tensor):
101
+ """
102
+ Args:
103
+ ms_tensor: mindspore.Tensor
104
+ Return:
105
+ mindtorch_tensor: mindtorch.Tensor
106
+ """
107
+
108
+ ms_dtype = ms_tensor.dtype
109
+
110
+ dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
111
+
112
+ if dtype_str not in dtype_str_to_mindtorch_dtype:
113
+ err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
114
+ logger.error_log_with_exp(err_msg,
115
+ ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
116
+ else:
117
+ mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)
118
+
119
+ if dtype_str in int_dtype_str_list:
120
+ middle_dtype = mindspore.int64
121
+ else:
122
+ middle_dtype = mindspore.float64
123
+
124
+ np_ndarray = ms_tensor.astype(middle_dtype).numpy()
125
+
126
+ mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype)
127
+
128
+ return mindtorch_tensor
129
+
89
130
  @staticmethod
90
131
  def transfer_to_mindspore_tensor(torch_tensor):
91
132
  '''
@@ -141,8 +182,11 @@ class ComputeElement:
141
182
  elif isinstance(self.parameter, DtypeMetaData):
142
183
  if tensor_platform == Const.MS_FRAMEWORK:
143
184
  parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
144
- else:
185
+ elif tensor_platform == Const.PT_FRAMEWORK:
145
186
  parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
187
+ elif tensor_platform == Const.MT_FRAMEWORK:
188
+ parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
189
+
146
190
  elif isinstance(self.parameter, MstensorMetaData):
147
191
  mstensor_meta_data = self.parameter
148
192
  ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
@@ -161,6 +205,8 @@ class ComputeElement:
161
205
  # if necessary, do transfer
162
206
  if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
163
207
  parameter = self.transfer_to_torch_tensor(parameter_tmp)
208
+ elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
209
+ parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
164
210
  elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
165
211
  parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
166
212
  else: