mindstudio-probe 1.2.1__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/RECORD +85 -66
- msprobe/README.md +2 -2
- msprobe/core/common/const.py +34 -9
- msprobe/core/common/inplace_ops.yaml +1 -0
- msprobe/core/common/utils.py +14 -0
- msprobe/core/compare/layer_mapping/data_scope_parser.py +1 -1
- msprobe/core/compare/merge_result/merge_result.py +8 -7
- msprobe/core/compare/merge_result/utils.py +81 -0
- msprobe/core/compare/utils.py +10 -0
- msprobe/core/data_dump/data_collector.py +58 -13
- msprobe/core/data_dump/data_processor/base.py +92 -8
- msprobe/core/data_dump/data_processor/factory.py +3 -0
- msprobe/core/data_dump/data_processor/mindspore_processor.py +17 -4
- msprobe/core/data_dump/data_processor/pytorch_processor.py +58 -7
- msprobe/core/data_dump/json_writer.py +26 -8
- msprobe/docs/01.installation.md +25 -0
- msprobe/docs/02.config_introduction.md +14 -12
- msprobe/docs/03.config_examples.md +24 -0
- msprobe/docs/05.data_dump_PyTorch.md +34 -15
- msprobe/docs/06.data_dump_MindSpore.md +45 -22
- msprobe/docs/09.accuracy_checker_MindSpore.md +4 -2
- msprobe/docs/19.monitor.md +257 -260
- msprobe/docs/21.visualization_PyTorch.md +10 -0
- msprobe/docs/22.visualization_MindSpore.md +11 -0
- msprobe/docs/27.dump_json_instruction.md +24 -20
- msprobe/docs/28.debugger_save_instruction.md +94 -0
- msprobe/docs/28.kernel_dump_MindSpore.md +69 -0
- msprobe/docs/img/monitor/step_count_per_record.png +0 -0
- msprobe/mindspore/__init__.py +1 -0
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +26 -6
- msprobe/mindspore/api_accuracy_checker/api_runner.py +54 -16
- msprobe/mindspore/api_accuracy_checker/compute_element.py +47 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +129 -0
- msprobe/mindspore/api_accuracy_checker/type_mapping.py +24 -1
- msprobe/mindspore/api_accuracy_checker/utils.py +6 -1
- msprobe/mindspore/common/utils.py +20 -2
- msprobe/mindspore/debugger/debugger_config.py +25 -2
- msprobe/mindspore/debugger/precision_debugger.py +25 -6
- msprobe/mindspore/dump/hook_cell/api_registry.py +2 -0
- msprobe/mindspore/dump/jit_dump.py +7 -6
- msprobe/mindspore/monitor/anomaly_detect.py +404 -0
- msprobe/mindspore/monitor/distributed/__init__.py +0 -0
- msprobe/mindspore/monitor/distributed/distributed_ops.yaml +15 -0
- msprobe/mindspore/monitor/distributed/stack_blacklist.yaml +5 -0
- msprobe/mindspore/monitor/distributed/wrap_distributed.py +300 -0
- msprobe/mindspore/monitor/features.py +63 -0
- msprobe/mindspore/monitor/module_hook.py +821 -0
- msprobe/mindspore/monitor/module_spec_verifier.py +94 -0
- msprobe/mindspore/monitor/utils.py +267 -0
- msprobe/mindspore/ms_config.py +8 -2
- msprobe/mindspore/service.py +95 -21
- msprobe/pytorch/__init__.py +0 -1
- msprobe/pytorch/api_accuracy_checker/generate_op_script/op_generator.py +1 -1
- msprobe/pytorch/bench_functions/apply_adam.py +215 -0
- msprobe/pytorch/bench_functions/group_norm_silu.py +27 -0
- msprobe/pytorch/bench_functions/mish.py +21 -0
- msprobe/pytorch/bench_functions/moe_gating_top_k_softmax.py +44 -0
- msprobe/pytorch/bench_functions/sort_v2.py +21 -0
- msprobe/pytorch/common/utils.py +71 -0
- msprobe/pytorch/debugger/debugger_config.py +19 -9
- msprobe/pytorch/debugger/precision_debugger.py +14 -0
- msprobe/pytorch/dump/module_dump/module_processer.py +10 -30
- msprobe/pytorch/function_factory.py +7 -1
- msprobe/pytorch/hook_module/support_wrap_ops.yaml +2 -1
- msprobe/pytorch/hook_module/wrap_distributed.py +4 -0
- msprobe/pytorch/monitor/anomaly_detect.py +14 -29
- msprobe/pytorch/monitor/csv2tb.py +10 -12
- msprobe/pytorch/monitor/module_hook.py +123 -104
- msprobe/pytorch/monitor/module_metric.py +6 -6
- msprobe/pytorch/monitor/optimizer_collect.py +45 -63
- msprobe/pytorch/monitor/utils.py +8 -43
- msprobe/pytorch/pt_config.py +19 -22
- msprobe/pytorch/service.py +103 -24
- msprobe/visualization/builder/graph_builder.py +31 -5
- msprobe/visualization/builder/msprobe_adapter.py +7 -5
- msprobe/visualization/graph/base_node.py +3 -2
- msprobe/visualization/graph/distributed_analyzer.py +80 -3
- msprobe/visualization/graph/node_op.py +4 -2
- msprobe/visualization/graph_service.py +3 -4
- msprobe/visualization/utils.py +10 -2
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-1.2.1.dist-info → mindstudio_probe-1.2.2.dist-info}/top_level.txt +0 -0
|
@@ -302,6 +302,16 @@ msprobe -f pytorch graph -i ./compare.json -o ./output
|
|
|
302
302
|
├── compare_stepn_rankn_{timestamp}.vis
|
|
303
303
|
```
|
|
304
304
|
|
|
305
|
+
#### 3.2.4 仅模型结构比对
|
|
306
|
+
|
|
307
|
+
适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。
|
|
308
|
+
|
|
309
|
+
使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。
|
|
310
|
+
|
|
311
|
+
dump配置请参考[dump配置示例](./03.config_examples.md#16-task-配置为-structure)
|
|
312
|
+
|
|
313
|
+
得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。
|
|
314
|
+
|
|
305
315
|
## 4.启动tensorboard
|
|
306
316
|
|
|
307
317
|
### 4.1 可直连的服务器
|
|
@@ -303,6 +303,17 @@ msprobe -f mindspore graph -i ./compare.json -o ./output
|
|
|
303
303
|
├── compare_stepn_rankn_{timestamp}.vis
|
|
304
304
|
```
|
|
305
305
|
|
|
306
|
+
#### 3.2.4 仅模型结构比对
|
|
307
|
+
|
|
308
|
+
适用场景:**主要关注模型结构而非训练过程数据**。例如,在模型迁移过程中,确保迁移前后模型结构的一致性,或在排查精度差异时,判断是否由模型结构差异所引起。
|
|
309
|
+
|
|
310
|
+
使用msprobe工具对模型数据进行采集时,**可选择仅采集模型结构(task配置为structure)**,此配置将避免采集模型训练过程的数据,从而显著减少采集所需的时间。
|
|
311
|
+
|
|
312
|
+
dump配置请参考[dump配置示例](./03.config_examples.md#35-task-配置为-structure)
|
|
313
|
+
|
|
314
|
+
得到dump数据后,若需比较特定两个rank之间的数据,请参考[3.2.2 双图比对](#322-双图比对);若需进行多个rank或多个step的数据批量比对,请参考[3.2.3 批量构建或比对](#323-批量构建或比对)。
|
|
315
|
+
|
|
316
|
+
|
|
306
317
|
## 4.启动tensorboard
|
|
307
318
|
|
|
308
319
|
### 4.1 可直连的服务器
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
# dump.json文件说明及示例
|
|
2
2
|
|
|
3
|
-
## 1. dump.json
|
|
3
|
+
## 1. dump.json文件示例(PyTorch)
|
|
4
4
|
|
|
5
5
|
### 1.1 L0级别
|
|
6
|
-
L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以
|
|
7
|
-
`output = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)
|
|
6
|
+
L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。以PyTorch的Conv2d模块为例,网络中模块调用代码为:
|
|
7
|
+
`output = self.conv2(input) # self.conv2 = torch.nn.Conv2d(64, 128, 5, padding=2, bias=True)`
|
|
8
8
|
|
|
9
|
-
dump.json
|
|
9
|
+
dump.json文件中包含以下数据名称:
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
11
|
+
- `Module.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
|
|
12
|
+
- `Module.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
|
|
13
|
+
- `Module.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
|
|
14
|
+
|
|
15
|
+
**说明**:当dump时传入的model参数为List[torch.nn.Module]或Tuple[torch.nn.Module]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Module}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Module.0.conv2.Conv2d.forward.0`。
|
|
14
16
|
|
|
15
17
|
```json
|
|
16
18
|
{
|
|
@@ -167,12 +169,12 @@ dump.json文件中包含以下字段:
|
|
|
167
169
|
```
|
|
168
170
|
|
|
169
171
|
### 1.2 L1级别
|
|
170
|
-
L1级别的dump.json文件包括API的前反向的输入输出。以
|
|
171
|
-
|
|
172
|
+
L1级别的dump.json文件包括API的前反向的输入输出。以PyTorch的relu函数为例,网络中API调用代码为:
|
|
173
|
+
`output = torch.nn.functional.relu(input)`
|
|
172
174
|
|
|
173
|
-
dump.json
|
|
174
|
-
|
|
175
|
-
|
|
175
|
+
dump.json文件中包含以下数据名称:
|
|
176
|
+
- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
|
|
177
|
+
- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
|
|
176
178
|
|
|
177
179
|
```json
|
|
178
180
|
{
|
|
@@ -272,12 +274,14 @@ mix级别的dump.json文件同时包括L0和L1级别的dump数据,文件格式
|
|
|
272
274
|
|
|
273
275
|
L0级别的dump.json文件包括模块的前反向的输入输出,以及模块的参数和参数梯度。
|
|
274
276
|
以MindSpore的Conv2d模块为例,dump.json文件中使用的模块调用代码为:
|
|
275
|
-
`output = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)
|
|
277
|
+
`output = self.conv2(input) # self.conv2 = mindspore.nn.Conv2d(64, 128, 5, pad_mode='same', has_bias=True)`
|
|
278
|
+
|
|
279
|
+
dump.json文件中包含以下数据名称:
|
|
280
|
+
- `Cell.conv2.Conv2d.forward.0`:模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
|
|
281
|
+
- `Cell.conv2.Conv2d.parameters_grad`:模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
|
|
282
|
+
- `Cell.conv2.Conv2d.backward.0`:模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
|
|
276
283
|
|
|
277
|
-
dump.
|
|
278
|
-
1. `Cell.conv2.Conv2d.forward.0`为模块的前向数据,其中input_args为模块的输入数据(位置参数),input_kwargs为模块的输入数据(关键字参数),output为模块的输出数据,parameters为模块的参数数据,包括权重(weight)和偏置(bias)。
|
|
279
|
-
2. `Cell.conv2.Conv2d.parameters_grad`为模块的参数梯度数据,包括权重(weight)和偏置(bias)的梯度。
|
|
280
|
-
3. `Cell.conv2.Conv2d.backward.0`为模块的反向数据,其中input为模块反向的输入梯度(对应前向输出的梯度),output为模块的反向输出梯度(对应前向输入的梯度)。
|
|
284
|
+
**说明**:当dump时传入的model参数为List[mindspore.nn.Cell]或Tuple[mindspore.nn.Cell]时,模块级数据的命名中包含该模块在列表中的索引index,命名格式为`{Cell}.{index}.*`,*表示以上三种模块级数据的命名格式,例如:`Cell.0.conv2.Conv2d.forward.0`。
|
|
281
285
|
|
|
282
286
|
```json
|
|
283
287
|
{
|
|
@@ -429,9 +433,9 @@ dump.json文件中包含以下字段:
|
|
|
429
433
|
L1级别的dump.json文件包括API的前反向的输入输出,以MindSpore的relu函数为例,网络中API调用代码为:
|
|
430
434
|
`output = mindspore.ops.relu(input)`
|
|
431
435
|
|
|
432
|
-
dump.json
|
|
433
|
-
|
|
434
|
-
|
|
436
|
+
dump.json文件中包含以下数据名称:
|
|
437
|
+
- `Functional.relu.0.forward`:API的前向数据,其中input_args为API的输入数据(位置参数),input_kwargs为API的输入数据(关键字参数),output为API的输出数据。
|
|
438
|
+
- `Functional.relu.0.backward`:API的反向数据,其中input为API的反向输入梯度(对应前向输出的梯度),output为API的反向输出梯度(对应前向输入的梯度)。
|
|
435
439
|
|
|
436
440
|
```json
|
|
437
441
|
{
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
# 单点保存工具 README
|
|
2
|
+
|
|
3
|
+
## 简介
|
|
4
|
+
L0, L1, mix dump存在盲区,网络中的非api/module的输入输出不会被批量dump下来。单点保存提供类似np.save和print的功能和使用体验,可以保存指定的变量。同时针对大模型场景进行了增强,具备以下特性:
|
|
5
|
+
- 可保存变量的反向梯度结果。
|
|
6
|
+
- 能直接保存嵌套结构数据(如 list、dict),无需手动遍历。
|
|
7
|
+
- 自动分 rank 保存。
|
|
8
|
+
- 多次调用时会自动计数。
|
|
9
|
+
- 可配置保存统计值或者张量。
|
|
10
|
+
|
|
11
|
+
## 支持场景
|
|
12
|
+
仅支持 PyTorch 与 MindSpore 的动态图场景。
|
|
13
|
+
|
|
14
|
+
## 使能方式
|
|
15
|
+
|
|
16
|
+
### 配置文件说明
|
|
17
|
+
|
|
18
|
+
通用配置:
|
|
19
|
+
|
|
20
|
+
| 参数 | 解释 | 是否必选 |
|
|
21
|
+
| -------- |-------------------------------------------| -------- |
|
|
22
|
+
| task | dump 的任务类型,str 类型。 单点保存场景仅支持传入"statistics", "tensor"。 | 是 |
|
|
23
|
+
| level | dump 级别,str 类型,根据不同级别采集不同数据。单点保存场景传入"debug"。 | 是 |
|
|
24
|
+
| dump_path | 设置 dump 数据目录路径,str 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 是 |
|
|
25
|
+
| rank | 指定对某张卡上的数据进行采集,list[Union[int, str]] 类型。细节详见[通用配置说明](./02.config_introduction.md#11-通用配置) | 否 |
|
|
26
|
+
|
|
27
|
+
"statistics" 任务子配置项:
|
|
28
|
+
| 参数 | 解释 | 是否必选 |
|
|
29
|
+
| -------- |-------------------------------------------| -------- |
|
|
30
|
+
| summary_mode | 控制 dump 文件输出的模式,str 类型。支持传入"statistics", "md5"。 细节详见[statistics任务子配置项说明](./02.config_introduction.md#12-task-配置为-statistics) | 否 |
|
|
31
|
+
|
|
32
|
+
"tensor" 任务无子配置项。
|
|
33
|
+
|
|
34
|
+
### 接口调用说明
|
|
35
|
+
|
|
36
|
+
调用PrecisionDebugger.save,传入需要保存的变量,指定变量名称以及是否需要保存反向数据。接口入参说明详见[pytorch单点保存接口](./05.data_dump_PyTorch.md#19-save),[mindspore单点保存接口](./06.data_dump_MindSpore.md#615-save)
|
|
37
|
+
|
|
38
|
+
### 实例(以pytorch场景为例)
|
|
39
|
+
|
|
40
|
+
配置文件
|
|
41
|
+
```json
|
|
42
|
+
{
|
|
43
|
+
"task": "statistics",
|
|
44
|
+
"dump_path": "./dump_path",
|
|
45
|
+
"rank": [],
|
|
46
|
+
"level": "debug",
|
|
47
|
+
"statistics": {
|
|
48
|
+
"summary_mode": "statistics"
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
```
|
|
52
|
+
|
|
53
|
+
初始化
|
|
54
|
+
```python
|
|
55
|
+
# 训练启动py脚本
|
|
56
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
57
|
+
debugger = PrecisionDebugger("./config.json")
|
|
58
|
+
for data, label in data_loader:
|
|
59
|
+
# 执行模型训练
|
|
60
|
+
train(data, label)
|
|
61
|
+
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
初始化(无配置文件)
|
|
65
|
+
```python
|
|
66
|
+
# 训练启动py脚本
|
|
67
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
68
|
+
debugger = PrecisionDebugger(dump_path="dump_path", level="debug")
|
|
69
|
+
for data, label in data_loader:
|
|
70
|
+
# 执行模型训练
|
|
71
|
+
train(data, label)
|
|
72
|
+
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
调用保存接口
|
|
76
|
+
```python
|
|
77
|
+
# 训练过程中被调用py文件
|
|
78
|
+
from msprobe.pytorch import PrecisionDebugger
|
|
79
|
+
dict_variable = {"key1": "value1", "key2": [1, 2]}
|
|
80
|
+
PrecisionDebugger.save(dict_variable, "dict_variable", save_backward=False)
|
|
81
|
+
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
## 输出结果
|
|
85
|
+
* **"task" 配置为 "statistics" 场景** :在 dump 目录下会生成包含变量统计值信息的 `debug.json` 文件。
|
|
86
|
+
* **"task" 配置为 "tensor" 场景** :除了在 dump 目录下生成包含变量统计值信息的 `debug.json` 文件外,还会在 dump 子目录 `dump_tensor_data` 中保存张量二进制文件,文件名称格式为 `{variable_name}{grad_flag}.{count}.tensor.{indexes}.{file_suffix}`。
|
|
87
|
+
|
|
88
|
+
- variable_name: 传入save接口的变量名称。
|
|
89
|
+
- grad_flag: 反向数据标识,反向数据为"_grad",正向数据为""。
|
|
90
|
+
- count: 调用计数,多次以相同变量名称调用时的计数。
|
|
91
|
+
- indexes: 索引,在保存嵌套结构数据时的索引。例如:嵌套结构为`{"key1": "value1", "key2": ["value2", "value3"]}`,"value2"的索引为"key2.0"
|
|
92
|
+
- file_suffix:文件后缀,pytorch场景为"pt",mindspore场景为"npy"
|
|
93
|
+
|
|
94
|
+
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
# MindSpore 场景的 kernel dump 说明
|
|
2
|
+
|
|
3
|
+
当使用 msprobe 数据采集功能时,level 配置为 "L2" 表示采集 kernel 层级的算子数据,仅支持昇腾 NPU 平台。
|
|
4
|
+
|
|
5
|
+
本文主要介绍 kernel dump 的配置示例和采集结果介绍, msprobe 数据采集功能的详细使用参考 《[MindSpore 场景的精度数据采集](./06.data_dump_MindSpore.md)》。
|
|
6
|
+
|
|
7
|
+
## 1 kernel dump 配置示例
|
|
8
|
+
|
|
9
|
+
使用 kernel dump 时,list 必须要填一个 API 名称,kernel dump 目前每个 step 只支持采集一个 API 的数据。
|
|
10
|
+
API 名称填写参考 L1 dump 结果文件 dump.json 中的API名称,命名格式为:`{api_type}.{api_name}.{API调用次数}.{forward/backward}`。
|
|
11
|
+
|
|
12
|
+
```json
|
|
13
|
+
{
|
|
14
|
+
"task": "tensor",
|
|
15
|
+
"dump_path": "/home/data_dump",
|
|
16
|
+
"level": "L2",
|
|
17
|
+
"rank": [],
|
|
18
|
+
"step": [],
|
|
19
|
+
"tensor": {
|
|
20
|
+
"scope": [],
|
|
21
|
+
"list": ["Functional.linear.0.backward"]
|
|
22
|
+
}
|
|
23
|
+
}
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## 2 结果文件介绍
|
|
27
|
+
|
|
28
|
+
### 2.1 采集结果说明
|
|
29
|
+
|
|
30
|
+
如果 API kernel 级数据采集成功,会打印以下信息:
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
The kernel data of {api_name} is dumped successfully.
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
注意:如果打印该信息后,没有数据生成,参考**常见问题3.1**进行排查。
|
|
37
|
+
|
|
38
|
+
如果 kernel dump 遇到不支持的 API, 会打印以下信息:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
The kernel dump does not support the {api_name} API.
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
其中 {api_name} 是对应的 API 名称。
|
|
45
|
+
|
|
46
|
+
### 2.2 输出文件说明
|
|
47
|
+
kernel dump 采集成功后,会在指定的 dump_path 目录下生成如下文件:
|
|
48
|
+
|
|
49
|
+
```
|
|
50
|
+
├── /home/data_dump/
|
|
51
|
+
│ ├── step0
|
|
52
|
+
│ │ ├── 20241201103000 # 日期时间格式,表示2024-12-01 10:30:00
|
|
53
|
+
│ │ │ ├── 0 # 表示 device id
|
|
54
|
+
│ │ │ │ ├──{op_type}.{op_name}.{task_id}.{stream_id}.{timestamp} # kernel 层算子数据
|
|
55
|
+
│ │ │ ...
|
|
56
|
+
│ │ ├── kernel_config_{device_id}.json # kernel dump 在接口调用过程中生成的中间文件,一般情况下无需关注
|
|
57
|
+
│ │ ...
|
|
58
|
+
│ ├── step1
|
|
59
|
+
│ ...
|
|
60
|
+
```
|
|
61
|
+
成功采集到数据后,可以使用 msprobe 工具提供的《[PyTorch 场景的数据解析](./14.data_parse_PyTorch.md)》功能分析数据。
|
|
62
|
+
|
|
63
|
+
## 3 常见问题
|
|
64
|
+
|
|
65
|
+
### 3.1 采集结果文件为空,有可能是什么原因?
|
|
66
|
+
|
|
67
|
+
1. 首先需要确认工具使用方式、配置文件内容、list 填写的 API 名称格式是否都正确无误。
|
|
68
|
+
|
|
69
|
+
2. 其次需要确认 API 是否运行在昇腾 NPU 上,如果是运行在其他设备上则不会存在 kernel 级数据。
|
|
Binary file
|
msprobe/mindspore/__init__.py
CHANGED
|
@@ -26,6 +26,7 @@ from msprobe.mindspore.api_accuracy_checker.data_manager import DataManager
|
|
|
26
26
|
from msprobe.mindspore.api_accuracy_checker.utils import (check_and_get_from_json_dict, global_context,
|
|
27
27
|
trim_output_compute_element_list)
|
|
28
28
|
from msprobe.mindspore.common.log import logger
|
|
29
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
29
30
|
|
|
30
31
|
cur_path = os.path.dirname(os.path.realpath(__file__))
|
|
31
32
|
yaml_path = os.path.join(cur_path, MsCompareConst.SUPPORTED_API_LIST_FILE)
|
|
@@ -82,9 +83,11 @@ class ApiAccuracyChecker:
|
|
|
82
83
|
# get output
|
|
83
84
|
if global_context.get_is_constructed():
|
|
84
85
|
# constructed situation, need use constructed input to run mindspore api getting tested_output
|
|
85
|
-
tested_outputs = api_runner(api_input_aggregation, api_name_str,
|
|
86
|
+
tested_outputs = api_runner(api_input_aggregation, api_name_str,
|
|
87
|
+
forward_or_backward, global_context.get_framework())
|
|
86
88
|
else:
|
|
87
89
|
tested_outputs = api_info.get_compute_element_list(forward_or_backward, Const.OUTPUT)
|
|
90
|
+
|
|
88
91
|
bench_outputs = api_runner(api_input_aggregation, api_name_str, forward_or_backward, Const.PT_FRAMEWORK)
|
|
89
92
|
tested_outputs = trim_output_compute_element_list(tested_outputs, forward_or_backward)
|
|
90
93
|
bench_outputs = trim_output_compute_element_list(bench_outputs, forward_or_backward)
|
|
@@ -153,13 +156,19 @@ class ApiAccuracyChecker:
|
|
|
153
156
|
real_api_str = Const.SEP.join(api_name_str_list[1:-2])
|
|
154
157
|
api_list = load_yaml(yaml_path)
|
|
155
158
|
supported_tensor_api_list = api_list.get(MsCompareConst.SUPPORTED_TENSOR_LIST_KEY)
|
|
156
|
-
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL)
|
|
159
|
+
if api_type_str in (MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL) \
|
|
160
|
+
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
161
|
+
return True
|
|
162
|
+
if api_type_str in MsCompareConst.MT_VALID_API_TYPES \
|
|
163
|
+
and global_context.get_framework() == Const.MT_FRAMEWORK:
|
|
157
164
|
return True
|
|
158
|
-
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list
|
|
165
|
+
if api_type_str == MsCompareConst.TENSOR_API and real_api_str in supported_tensor_api_list \
|
|
166
|
+
and global_context.get_framework() == Const.MS_FRAMEWORK:
|
|
159
167
|
return True
|
|
160
168
|
return False
|
|
161
169
|
|
|
162
170
|
def parse(self, api_info_path):
|
|
171
|
+
|
|
163
172
|
api_info_dict = load_json(api_info_path)
|
|
164
173
|
|
|
165
174
|
# init global context
|
|
@@ -167,14 +176,25 @@ class ApiAccuracyChecker:
|
|
|
167
176
|
"task field in api_info.json", accepted_type=str,
|
|
168
177
|
accepted_value=(MsCompareConst.STATISTICS_TASK,
|
|
169
178
|
MsCompareConst.TENSOR_TASK))
|
|
179
|
+
try:
|
|
180
|
+
framework = check_and_get_from_json_dict(api_info_dict, MsCompareConst.FRAMEWORK,
|
|
181
|
+
"framework field in api_info.json", accepted_type=str,
|
|
182
|
+
accepted_value=(Const.MS_FRAMEWORK,
|
|
183
|
+
Const.MT_FRAMEWORK))
|
|
184
|
+
except Exception as e:
|
|
185
|
+
framework = Const.MS_FRAMEWORK
|
|
186
|
+
logger.warning(f"JSON parsing error in framework field: {e}")
|
|
187
|
+
|
|
188
|
+
if framework == Const.MT_FRAMEWORK and not torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
189
|
+
raise Exception(f"Please check if you have a valid PyTorch and MindTorch environment")
|
|
190
|
+
|
|
170
191
|
is_constructed = task == MsCompareConst.STATISTICS_TASK
|
|
171
192
|
if not is_constructed:
|
|
172
193
|
dump_data_dir = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DUMP_DATA_DIR_FIELD,
|
|
173
|
-
"dump_data_dir field in api_info.json",
|
|
174
|
-
accepted_type=str)
|
|
194
|
+
"dump_data_dir field in api_info.json", accepted_type=str)
|
|
175
195
|
else:
|
|
176
196
|
dump_data_dir = ""
|
|
177
|
-
global_context.init(is_constructed, dump_data_dir)
|
|
197
|
+
global_context.init(is_constructed, dump_data_dir, framework)
|
|
178
198
|
|
|
179
199
|
api_info_data = check_and_get_from_json_dict(api_info_dict, MsCompareConst.DATA_FIELD,
|
|
180
200
|
"data field in api_info.json", accepted_type=dict)
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
# limitations under the License.
|
|
15
15
|
|
|
16
16
|
import mindspore
|
|
17
|
-
import torch
|
|
18
17
|
from mindspore import ops
|
|
19
18
|
from msprobe.core.common.const import Const, MsCompareConst
|
|
20
19
|
from msprobe.core.common.exceptions import ApiAccuracyCheckerException
|
|
@@ -24,14 +23,28 @@ from msprobe.mindspore.api_accuracy_checker.utils import convert_to_tuple
|
|
|
24
23
|
from msprobe.mindspore.common.log import logger
|
|
25
24
|
|
|
26
25
|
|
|
26
|
+
from msprobe.mindspore.api_accuracy_checker import torch_mindtorch_importer
|
|
27
|
+
|
|
28
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
29
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_tensor
|
|
30
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_func
|
|
31
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch_dist
|
|
32
|
+
|
|
33
|
+
if torch_mindtorch_importer.is_valid_pt_mt_env:
|
|
34
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
35
|
+
else:
|
|
36
|
+
import torch
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
|
|
27
40
|
class ApiInputAggregation:
|
|
28
41
|
def __init__(self, inputs, kwargs, gradient_inputs) -> None:
|
|
29
|
-
|
|
42
|
+
"""
|
|
30
43
|
Args:
|
|
31
44
|
inputs: List[ComputeElement]
|
|
32
45
|
kwargs: dict{str: ComputeElement}
|
|
33
46
|
gradient_inputs: Union[List[ComputeElement], None]
|
|
34
|
-
|
|
47
|
+
"""
|
|
35
48
|
self.inputs = inputs
|
|
36
49
|
self.kwargs = kwargs
|
|
37
50
|
self.gradient_inputs = gradient_inputs
|
|
@@ -43,16 +56,34 @@ api_parent_module_mapping = {
|
|
|
43
56
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): mindspore.mint.nn.functional,
|
|
44
57
|
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
45
58
|
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): mindspore.Tensor,
|
|
46
|
-
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor
|
|
59
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): torch.Tensor,
|
|
60
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): mindtorch_tensor,
|
|
61
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): torch.Tensor,
|
|
62
|
+
(MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): mindtorch,
|
|
63
|
+
(MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): torch,
|
|
64
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): mindtorch_func,
|
|
65
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): torch.nn.functional,
|
|
66
|
+
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): mindtorch_dist,
|
|
67
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): torch.distributed
|
|
68
|
+
|
|
47
69
|
}
|
|
48
70
|
|
|
71
|
+
|
|
49
72
|
api_parent_module_str_mapping = {
|
|
50
73
|
(MsCompareConst.MINT, Const.MS_FRAMEWORK): "mindspore.mint",
|
|
51
74
|
(MsCompareConst.MINT, Const.PT_FRAMEWORK): "torch",
|
|
52
75
|
(MsCompareConst.MINT_FUNCTIONAL, Const.MS_FRAMEWORK): "mindspore.mint.nn.functional",
|
|
53
76
|
(MsCompareConst.MINT_FUNCTIONAL, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
54
77
|
(MsCompareConst.TENSOR_API, Const.MS_FRAMEWORK): "mindspore.Tensor",
|
|
55
|
-
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor"
|
|
78
|
+
(MsCompareConst.TENSOR_API, Const.PT_FRAMEWORK): "torch.Tensor",
|
|
79
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.MT_FRAMEWORK): "mindtorch_tensor",
|
|
80
|
+
(MsCompareConst.MINDTORCH_TENSOR, Const.PT_FRAMEWORK): "torch.Tensor",
|
|
81
|
+
(MsCompareConst.MINDTORCH, Const.MT_FRAMEWORK): "mindtorch",
|
|
82
|
+
(MsCompareConst.MINDTORCH, Const.PT_FRAMEWORK): "torch",
|
|
83
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.MT_FRAMEWORK): "mindtorch_func",
|
|
84
|
+
(MsCompareConst.MINDTORCH_FUNC, Const.PT_FRAMEWORK): "torch.nn.functional",
|
|
85
|
+
(MsCompareConst.MINDTORCH_DIST, Const.MT_FRAMEWORK): "mindtorch_dist",
|
|
86
|
+
(MsCompareConst.MINDTORCH_DIST, Const.PT_FRAMEWORK): "torch.distributed"
|
|
56
87
|
}
|
|
57
88
|
|
|
58
89
|
|
|
@@ -64,7 +95,7 @@ class ApiRunner:
|
|
|
64
95
|
api_input_aggregation: ApiInputAggregation
|
|
65
96
|
api_name_str: str, e.g. "MintFunctional.relu.0"
|
|
66
97
|
forward_or_backward: str, Union["forward", "backward"]
|
|
67
|
-
api_platform: str, Union["mindspore", "torch"]
|
|
98
|
+
api_platform: str, Union["mindspore", "torch", "mindtorch"]
|
|
68
99
|
|
|
69
100
|
Return:
|
|
70
101
|
outputs: list[ComputeElement]
|
|
@@ -72,35 +103,41 @@ class ApiRunner:
|
|
|
72
103
|
Description:
|
|
73
104
|
run mindspore.mint/torch api
|
|
74
105
|
'''
|
|
75
|
-
|
|
106
|
+
|
|
107
|
+
api_type_str, api_sub_name = self.get_info_from_name(api_name_str, api_platform)
|
|
76
108
|
api_instance = self.get_api_instance(api_type_str, api_sub_name, api_platform)
|
|
77
109
|
|
|
78
110
|
return self.run_api(api_instance, api_input_aggregation, forward_or_backward, api_platform)
|
|
79
111
|
|
|
80
112
|
@staticmethod
|
|
81
|
-
def get_info_from_name(api_name_str):
|
|
82
|
-
|
|
113
|
+
def get_info_from_name(api_name_str, api_platform=Const.MS_FRAMEWORK):
|
|
114
|
+
"""
|
|
83
115
|
Args:
|
|
84
116
|
api_name_str: str, the trimmed key of data dict in api_info.json. e.g. "MintFunctional.relu.0"
|
|
85
|
-
|
|
117
|
+
api_platform: str, the platform for the API, which can be either "mindspore" or "mindtorch".
|
|
118
|
+
It specifies which framework is being used. Default is "mindspore".
|
|
86
119
|
Return:
|
|
87
|
-
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
120
|
+
api_type_str: str, Union["MintFunctional", "Mint", "Tensor", "Torch", "Functional"]
|
|
88
121
|
api_sub_name: str, e.g. "relu"
|
|
89
|
-
|
|
122
|
+
"""
|
|
90
123
|
api_name_list = api_name_str.split(Const.SEP)
|
|
91
124
|
if len(api_name_list) != 3:
|
|
92
125
|
err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format"
|
|
93
126
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
94
127
|
api_type_str, api_sub_name = api_name_list[0], api_name_list[1]
|
|
95
|
-
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API]
|
|
128
|
+
if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API] \
|
|
129
|
+
and api_platform == Const.MS_FRAMEWORK:
|
|
96
130
|
err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api"
|
|
97
131
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
98
132
|
|
|
133
|
+
if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK:
|
|
134
|
+
err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api"
|
|
135
|
+
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
99
136
|
return api_type_str, api_sub_name
|
|
100
137
|
|
|
101
138
|
@staticmethod
|
|
102
139
|
def get_api_instance(api_type_str, api_sub_name, api_platform):
|
|
103
|
-
|
|
140
|
+
"""
|
|
104
141
|
Args:
|
|
105
142
|
api_type_str: str, Union["MintFunctional", "Mint", "Tensor"]
|
|
106
143
|
api_sub_name: str, e.g. "relu"
|
|
@@ -113,11 +150,12 @@ class ApiRunner:
|
|
|
113
150
|
get mindspore.mint/torch api fucntion
|
|
114
151
|
mindspore.mint.{api_sub_name} <--> torch.{api_sub_name}
|
|
115
152
|
mindspore.mint.nn.functional.{api_sub_name} <--> torch.nn.functional.{api_sub_name}
|
|
116
|
-
|
|
153
|
+
"""
|
|
117
154
|
|
|
118
155
|
api_parent_module = api_parent_module_mapping.get((api_type_str, api_platform))
|
|
119
156
|
api_parent_module_str = api_parent_module_str_mapping.get((api_type_str, api_platform))
|
|
120
157
|
full_api_name = api_parent_module_str + Const.SEP + api_sub_name
|
|
158
|
+
|
|
121
159
|
if not hasattr(api_parent_module, api_sub_name):
|
|
122
160
|
err_msg = f"ApiRunner.get_api_instance failed: {full_api_name} is not found"
|
|
123
161
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.ApiWrong))
|
|
@@ -147,7 +185,7 @@ class ApiRunner:
|
|
|
147
185
|
logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue))
|
|
148
186
|
gradient_inputs = tuple(compute_element.get_parameter(get_origin=False, tensor_platform=api_platform)
|
|
149
187
|
for compute_element in gradient_inputs)
|
|
150
|
-
if api_platform == Const.MS_FRAMEWORK:
|
|
188
|
+
if api_platform == Const.MS_FRAMEWORK or api_platform == Const.MT_FRAMEWORK:
|
|
151
189
|
if len(gradient_inputs) == 1:
|
|
152
190
|
gradient_inputs = gradient_inputs[0]
|
|
153
191
|
|
|
@@ -25,6 +25,7 @@ from msprobe.core.common.file_utils import load_npy
|
|
|
25
25
|
from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_str_to_type,
|
|
26
26
|
ms_dtype_to_dtype_str, torch_dtype_to_dtype_str,
|
|
27
27
|
dtype_str_to_ms_dtype, dtype_str_to_np_dtype,
|
|
28
|
+
dtype_str_to_mindtorch_dtype,
|
|
28
29
|
dtype_str_to_torch_dtype, type_to_api_info_type_str,
|
|
29
30
|
DEFAULT_CONSTRUCT_NP_FLOAT_DTYPE, TUPLE_TYPE_STR,
|
|
30
31
|
MINDSPORE_TENSOR_TYPE_STR, MINDSPORE_DTYPE_TYPE_STR,
|
|
@@ -33,6 +34,15 @@ from msprobe.mindspore.api_accuracy_checker.type_mapping import (api_info_type_s
|
|
|
33
34
|
from msprobe.mindspore.api_accuracy_checker.utils import check_and_get_from_json_dict, global_context
|
|
34
35
|
from msprobe.mindspore.common.log import logger
|
|
35
36
|
|
|
37
|
+
import msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer as env_module
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
if env_module.is_valid_pt_mt_env:
|
|
41
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import mindtorch
|
|
42
|
+
from msprobe.mindspore.api_accuracy_checker.torch_mindtorch_importer import torch
|
|
43
|
+
else:
|
|
44
|
+
import torch
|
|
45
|
+
|
|
36
46
|
|
|
37
47
|
class MstensorMetaData:
|
|
38
48
|
def __init__(self, dtype_str, npy_path, maximum, minimum, shape) -> None:
|
|
@@ -86,6 +96,37 @@ class ComputeElement:
|
|
|
86
96
|
torch_tensor = torch.from_numpy(np_ndarray).to(torch_dtype)
|
|
87
97
|
return torch_tensor
|
|
88
98
|
|
|
99
|
+
@staticmethod
|
|
100
|
+
def transfer_to_mindtorch_tensor(ms_tensor):
|
|
101
|
+
"""
|
|
102
|
+
Args:
|
|
103
|
+
ms_tensor: mindspore.Tensor
|
|
104
|
+
Return:
|
|
105
|
+
mindtorch_tensor: mindtorch.Tensor
|
|
106
|
+
"""
|
|
107
|
+
|
|
108
|
+
ms_dtype = ms_tensor.dtype
|
|
109
|
+
|
|
110
|
+
dtype_str = ms_dtype_to_dtype_str.get(ms_dtype)
|
|
111
|
+
|
|
112
|
+
if dtype_str not in dtype_str_to_mindtorch_dtype:
|
|
113
|
+
err_msg = f"ComputeElement.transfer_to_mindtorch_tensor failed: no matching mindtorch dtype for {dtype_str}"
|
|
114
|
+
logger.error_log_with_exp(err_msg,
|
|
115
|
+
ApiAccuracyCheckerException(ApiAccuracyCheckerException.UnsupportType))
|
|
116
|
+
else:
|
|
117
|
+
mindtorch_dtype = dtype_str_to_mindtorch_dtype.get(dtype_str)
|
|
118
|
+
|
|
119
|
+
if dtype_str in int_dtype_str_list:
|
|
120
|
+
middle_dtype = mindspore.int64
|
|
121
|
+
else:
|
|
122
|
+
middle_dtype = mindspore.float64
|
|
123
|
+
|
|
124
|
+
np_ndarray = ms_tensor.astype(middle_dtype).numpy()
|
|
125
|
+
|
|
126
|
+
mindtorch_tensor = mindtorch.from_numpy(np_ndarray).to(ms_dtype)
|
|
127
|
+
|
|
128
|
+
return mindtorch_tensor
|
|
129
|
+
|
|
89
130
|
@staticmethod
|
|
90
131
|
def transfer_to_mindspore_tensor(torch_tensor):
|
|
91
132
|
'''
|
|
@@ -141,8 +182,11 @@ class ComputeElement:
|
|
|
141
182
|
elif isinstance(self.parameter, DtypeMetaData):
|
|
142
183
|
if tensor_platform == Const.MS_FRAMEWORK:
|
|
143
184
|
parameter_tmp = dtype_str_to_ms_dtype.get(self.parameter.dtype_str)
|
|
144
|
-
|
|
185
|
+
elif tensor_platform == Const.PT_FRAMEWORK:
|
|
145
186
|
parameter_tmp = dtype_str_to_torch_dtype.get(self.parameter.dtype_str)
|
|
187
|
+
elif tensor_platform == Const.MT_FRAMEWORK:
|
|
188
|
+
parameter_tmp = dtype_str_to_mindtorch_dtype.get(self.parameter.dtype_str)
|
|
189
|
+
|
|
146
190
|
elif isinstance(self.parameter, MstensorMetaData):
|
|
147
191
|
mstensor_meta_data = self.parameter
|
|
148
192
|
ms_dtype = dtype_str_to_ms_dtype.get(mstensor_meta_data.dtype_str)
|
|
@@ -161,6 +205,8 @@ class ComputeElement:
|
|
|
161
205
|
# if necessary, do transfer
|
|
162
206
|
if not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.PT_FRAMEWORK:
|
|
163
207
|
parameter = self.transfer_to_torch_tensor(parameter_tmp)
|
|
208
|
+
elif not get_origin and isinstance(parameter_tmp, mindspore.Tensor) and tensor_platform == Const.MT_FRAMEWORK:
|
|
209
|
+
parameter = self.transfer_to_mindtorch_tensor(parameter_tmp)
|
|
164
210
|
elif not get_origin and isinstance(parameter_tmp, torch.Tensor) and tensor_platform == Const.MS_FRAMEWORK:
|
|
165
211
|
parameter = self.transfer_to_mindspore_tensor(parameter_tmp)
|
|
166
212
|
else:
|