mindstudio-probe 1.0.4__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/METADATA +1 -1
  2. mindstudio_probe-1.1.0.dist-info/RECORD +287 -0
  3. msprobe/README.md +46 -16
  4. msprobe/__init__.py +16 -1
  5. msprobe/config.json +0 -2
  6. msprobe/core/advisor/advisor.py +8 -8
  7. msprobe/core/advisor/advisor_const.py +6 -7
  8. msprobe/core/advisor/advisor_result.py +12 -12
  9. msprobe/core/common/const.py +64 -3
  10. msprobe/core/common/exceptions.py +2 -2
  11. msprobe/core/common/file_utils.py +54 -9
  12. msprobe/core/common/inplace_op_checker.py +38 -0
  13. msprobe/core/common/inplace_ops.yaml +251 -0
  14. msprobe/core/common/log.py +21 -11
  15. msprobe/core/common/utils.py +153 -167
  16. msprobe/core/common_config.py +18 -25
  17. msprobe/core/compare/acc_compare.py +209 -36
  18. msprobe/core/compare/check.py +102 -17
  19. msprobe/core/compare/compare_cli.py +21 -1
  20. msprobe/core/compare/highlight.py +41 -5
  21. msprobe/core/compare/multiprocessing_compute.py +33 -8
  22. msprobe/core/compare/npy_compare.py +21 -6
  23. msprobe/core/compare/utils.py +82 -48
  24. msprobe/core/data_dump/data_collector.py +31 -32
  25. msprobe/core/data_dump/data_processor/base.py +45 -22
  26. msprobe/core/data_dump/data_processor/factory.py +20 -3
  27. msprobe/core/data_dump/data_processor/mindspore_processor.py +11 -5
  28. msprobe/core/data_dump/data_processor/pytorch_processor.py +24 -7
  29. msprobe/core/data_dump/json_writer.py +63 -42
  30. msprobe/core/data_dump/scope.py +32 -16
  31. msprobe/core/grad_probe/constant.py +4 -0
  32. msprobe/core/grad_probe/grad_compare.py +2 -3
  33. msprobe/core/grad_probe/utils.py +16 -3
  34. msprobe/docs/01.installation.md +19 -9
  35. msprobe/docs/02.config_introduction.md +52 -80
  36. msprobe/docs/03.config_examples.md +3 -13
  37. msprobe/docs/04.acl_config_examples.md +11 -9
  38. msprobe/docs/05.data_dump_PyTorch.md +140 -12
  39. msprobe/docs/06.data_dump_MindSpore.md +47 -5
  40. msprobe/docs/07.accuracy_checker_PyTorch.md +57 -34
  41. msprobe/docs/08.accuracy_checker_online_PyTorch.md +51 -11
  42. msprobe/docs/09.accuracy_checker_MindSpore.md +8 -8
  43. msprobe/docs/10.accuracy_compare_PyTorch.md +181 -99
  44. msprobe/docs/11.accuracy_compare_MindSpore.md +162 -31
  45. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  46. msprobe/docs/15.free_benchmarking_PyTorch.md +59 -53
  47. msprobe/docs/16.free_benchmarking_MindSpore.md +140 -0
  48. msprobe/docs/17.grad_probe.md +14 -16
  49. msprobe/docs/18.online_dispatch.md +89 -0
  50. msprobe/docs/{FAQ_PyTorch.md → FAQ.md} +22 -10
  51. msprobe/docs/img/ms_dump.png +0 -0
  52. msprobe/docs/img/ms_layer.png +0 -0
  53. msprobe/docs/img/pt_dump.png +0 -0
  54. msprobe/mindspore/__init__.py +1 -0
  55. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +35 -11
  56. msprobe/mindspore/api_accuracy_checker/api_info.py +7 -0
  57. msprobe/mindspore/cell_processor.py +27 -3
  58. msprobe/mindspore/common/const.py +2 -0
  59. msprobe/mindspore/common/utils.py +18 -2
  60. msprobe/mindspore/compare/distributed_compare.py +9 -22
  61. msprobe/mindspore/compare/layer_mapping.py +146 -0
  62. msprobe/mindspore/compare/modify_mapping.py +107 -0
  63. msprobe/mindspore/compare/ms_compare.py +173 -35
  64. msprobe/mindspore/compare/ms_graph_compare.py +27 -11
  65. msprobe/mindspore/debugger/debugger_config.py +16 -13
  66. msprobe/mindspore/debugger/precision_debugger.py +37 -13
  67. msprobe/mindspore/dump/dump_tool_factory.py +16 -1
  68. msprobe/mindspore/dump/hook_cell/api_registry.py +11 -1
  69. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +206 -0
  70. msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +82 -10
  71. msprobe/mindspore/dump/hook_cell/wrap_api.py +21 -13
  72. msprobe/mindspore/dump/jit_dump.py +41 -17
  73. msprobe/mindspore/dump/kernel_graph_dump.py +19 -3
  74. msprobe/mindspore/dump/kernel_kbyk_dump.py +19 -4
  75. msprobe/mindspore/free_benchmark/api_pynative_self_check.py +19 -4
  76. msprobe/mindspore/free_benchmark/common/config.py +15 -0
  77. msprobe/mindspore/free_benchmark/common/handler_params.py +15 -0
  78. msprobe/mindspore/free_benchmark/common/utils.py +19 -5
  79. msprobe/mindspore/free_benchmark/decorator/dec_forward.py +16 -2
  80. msprobe/mindspore/free_benchmark/decorator/decorator_factory.py +18 -3
  81. msprobe/mindspore/free_benchmark/handler/base_handler.py +18 -3
  82. msprobe/mindspore/free_benchmark/handler/check_handler.py +18 -3
  83. msprobe/mindspore/free_benchmark/handler/fix_handler.py +15 -0
  84. msprobe/mindspore/free_benchmark/handler/handler_factory.py +18 -3
  85. msprobe/mindspore/free_benchmark/perturbation/add_noise.py +22 -7
  86. msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py +15 -0
  87. msprobe/mindspore/free_benchmark/perturbation/bit_noise.py +22 -7
  88. msprobe/mindspore/free_benchmark/perturbation/exchange_value.py +44 -18
  89. msprobe/mindspore/free_benchmark/perturbation/improve_precision.py +18 -4
  90. msprobe/mindspore/free_benchmark/perturbation/no_change.py +16 -1
  91. msprobe/mindspore/free_benchmark/perturbation/perturbation_factory.py +20 -5
  92. msprobe/mindspore/free_benchmark/self_check_tool_factory.py +15 -0
  93. msprobe/mindspore/grad_probe/global_context.py +18 -8
  94. msprobe/mindspore/overflow_check/kernel_graph_overflow_check.py +20 -4
  95. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +15 -0
  96. msprobe/mindspore/service.py +42 -123
  97. msprobe/pytorch/__init__.py +20 -1
  98. msprobe/pytorch/api_accuracy_checker/common/config.py +19 -2
  99. msprobe/pytorch/api_accuracy_checker/common/utils.py +53 -21
  100. msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +19 -2
  101. msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +47 -21
  102. msprobe/pytorch/api_accuracy_checker/compare/compare.py +51 -21
  103. msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +23 -6
  104. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +28 -8
  105. msprobe/pytorch/api_accuracy_checker/config.yaml +1 -1
  106. msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +67 -32
  107. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +26 -5
  108. msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +19 -2
  109. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +51 -125
  110. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut_utils.py +146 -3
  111. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py +21 -0
  112. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py +78 -33
  113. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +27 -4
  114. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/dump_dispatch.py +110 -0
  115. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py +36 -11
  116. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/torch_ops_config.yaml +63 -0
  117. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +44 -0
  118. msprobe/pytorch/bench_functions/__init__.py +18 -3
  119. msprobe/pytorch/bench_functions/apply_adam_w.py +15 -0
  120. msprobe/pytorch/bench_functions/confusion_transpose.py +15 -0
  121. msprobe/pytorch/bench_functions/fast_gelu.py +15 -0
  122. msprobe/pytorch/bench_functions/layer_norm_eval.py +15 -0
  123. msprobe/pytorch/bench_functions/linear.py +15 -0
  124. msprobe/pytorch/bench_functions/matmul_backward.py +21 -6
  125. msprobe/pytorch/bench_functions/npu_fusion_attention.py +180 -151
  126. msprobe/pytorch/bench_functions/rms_norm.py +15 -0
  127. msprobe/pytorch/bench_functions/rotary_mul.py +28 -9
  128. msprobe/pytorch/bench_functions/scaled_mask_softmax.py +15 -0
  129. msprobe/pytorch/bench_functions/swiglu.py +20 -5
  130. msprobe/pytorch/common/__init__.py +15 -0
  131. msprobe/pytorch/common/log.py +18 -6
  132. msprobe/pytorch/common/parse_json.py +26 -11
  133. msprobe/pytorch/common/utils.py +40 -35
  134. msprobe/pytorch/compare/distributed_compare.py +11 -11
  135. msprobe/pytorch/compare/match.py +15 -0
  136. msprobe/pytorch/compare/pt_compare.py +38 -6
  137. msprobe/pytorch/debugger/debugger_config.py +52 -39
  138. msprobe/pytorch/debugger/precision_debugger.py +72 -24
  139. msprobe/pytorch/free_benchmark/__init__.py +20 -5
  140. msprobe/pytorch/free_benchmark/common/enums.py +28 -0
  141. msprobe/pytorch/free_benchmark/common/params.py +15 -0
  142. msprobe/pytorch/free_benchmark/common/utils.py +17 -1
  143. msprobe/pytorch/free_benchmark/compare/grad_saver.py +28 -7
  144. msprobe/pytorch/free_benchmark/compare/single_benchmark.py +15 -0
  145. msprobe/pytorch/free_benchmark/main.py +19 -4
  146. msprobe/pytorch/free_benchmark/perturbed_layers/base_layer.py +15 -0
  147. msprobe/pytorch/free_benchmark/perturbed_layers/layer_factory.py +19 -4
  148. msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +15 -0
  149. msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +15 -0
  150. msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +26 -2
  151. msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +15 -0
  152. msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +15 -0
  153. msprobe/pytorch/free_benchmark/perturbed_layers/npu/npu_base_layser.py +15 -0
  154. msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +15 -0
  155. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +55 -16
  156. msprobe/pytorch/free_benchmark/result_handlers/check_handler.py +15 -0
  157. msprobe/pytorch/free_benchmark/result_handlers/fix_handler.py +15 -0
  158. msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +15 -0
  159. msprobe/pytorch/free_benchmark/result_handlers/preheat_handler.py +19 -4
  160. msprobe/pytorch/function_factory.py +17 -2
  161. msprobe/pytorch/functional/module_dump.py +84 -0
  162. msprobe/pytorch/grad_probe/grad_stat_csv.py +2 -2
  163. msprobe/pytorch/hook_module/__init__.py +16 -1
  164. msprobe/pytorch/hook_module/api_registry.py +13 -8
  165. msprobe/pytorch/hook_module/hook_module.py +17 -19
  166. msprobe/pytorch/hook_module/utils.py +4 -6
  167. msprobe/pytorch/hook_module/wrap_aten.py +12 -11
  168. msprobe/pytorch/hook_module/wrap_distributed.py +6 -7
  169. msprobe/pytorch/hook_module/wrap_functional.py +10 -11
  170. msprobe/pytorch/hook_module/wrap_npu_custom.py +9 -17
  171. msprobe/pytorch/hook_module/wrap_tensor.py +4 -6
  172. msprobe/pytorch/hook_module/wrap_torch.py +4 -6
  173. msprobe/pytorch/hook_module/wrap_vf.py +4 -6
  174. msprobe/pytorch/module_processer.py +17 -2
  175. msprobe/pytorch/online_dispatch/compare.py +11 -12
  176. msprobe/pytorch/online_dispatch/single_compare.py +7 -7
  177. msprobe/pytorch/online_dispatch/torch_ops_config.yaml +8 -0
  178. msprobe/pytorch/online_dispatch/utils.py +1 -4
  179. msprobe/pytorch/parse.py +15 -0
  180. msprobe/pytorch/parse_tool/cli.py +5 -6
  181. msprobe/pytorch/parse_tool/lib/compare.py +9 -10
  182. msprobe/pytorch/parse_tool/lib/parse_tool.py +3 -0
  183. msprobe/pytorch/parse_tool/lib/utils.py +28 -24
  184. msprobe/pytorch/parse_tool/lib/visualization.py +1 -1
  185. msprobe/pytorch/pt_config.py +167 -38
  186. msprobe/pytorch/service.py +97 -32
  187. mindstudio_probe-1.0.4.dist-info/RECORD +0 -276
  188. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py +0 -10
  189. msprobe/pytorch/functional/data_processor.py +0 -0
  190. msprobe/pytorch/functional/dump_module.py +0 -39
  191. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/LICENSE +0 -0
  192. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/WHEEL +0 -0
  193. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/entry_points.txt +0 -0
  194. {mindstudio_probe-1.0.4.dist-info → mindstudio_probe-1.1.0.dist-info}/top_level.txt +0 -0
@@ -1,2 +1,17 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  from .parse_json import parse_json_info_forward_backward
2
17
  from .utils import seed_all
@@ -1,9 +1,21 @@
1
- import os
2
- import time
3
- import sys
4
- from msprobe.pytorch.common.utils import get_rank_if_initialized
5
- from msprobe.core.common.log import BaseLogger
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
6
16
  from msprobe.core.common.exceptions import DistributedNotInitializedError
17
+ from msprobe.core.common.log import BaseLogger
18
+ from msprobe.pytorch.common.utils import get_rank_if_initialized
7
19
 
8
20
 
9
21
  class PyTorchLogger(BaseLogger):
@@ -18,4 +30,4 @@ class PyTorchLogger(BaseLogger):
18
30
  return current_rank
19
31
 
20
32
 
21
- logger = PyTorchLogger()
33
+ logger = PyTorchLogger()
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import json
2
17
 
3
18
  from msprobe.core.common.exceptions import ParseJsonException
@@ -5,14 +20,6 @@ from msprobe.core.common.file_utils import FileOpen
5
20
 
6
21
 
7
22
  def parse_json_info_forward_backward(json_path):
8
- def parse_data_name_with_pattern(data_name, pattern):
9
- name_struct = data_name.split('.')
10
- if not name_struct[-1] == pattern:
11
- raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
12
- f"{data_name} in file {json_path}")
13
- api_name = '.'.join(name_struct[:-1])
14
- return api_name
15
-
16
23
  with FileOpen(json_path, 'r') as f:
17
24
  dump_json = json.load(f)
18
25
 
@@ -27,13 +34,21 @@ def parse_json_info_forward_backward(json_path):
27
34
  if "Module" in data_name:
28
35
  continue
29
36
  if "forward" in data_name:
30
- api_name = parse_data_name_with_pattern(data_name, "forward")
37
+ api_name = parse_data_name_with_pattern(data_name, "forward", json_path)
31
38
  forward_data.update({api_name: data_item})
32
39
  elif "backward" in data_name:
33
- api_name = parse_data_name_with_pattern(data_name, "backward")
40
+ api_name = parse_data_name_with_pattern(data_name, "backward", json_path)
34
41
  backward_data.update({api_name: data_item})
35
42
  else:
36
43
  raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
37
- f"{data_name} in file {json_path}.")
44
+ f"{data_name} in file {json_path}.")
38
45
 
39
46
  return forward_data, backward_data, real_data_path
47
+
48
+
49
+ def parse_data_name_with_pattern(data_name, pattern, json_path):
50
+ name_struct = data_name.split('.')
51
+ if not name_struct[-1] == pattern:
52
+ raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, f"{data_name} in file {json_path}")
53
+ api_name = '.'.join(name_struct[:-1])
54
+ return api_name
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,20 +12,22 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import io
18
17
  import os
19
18
  import random
20
19
  import stat
20
+ from functools import wraps
21
+
22
+ import numpy as np
21
23
  import torch
22
24
  import torch.distributed as dist
23
- import numpy as np
24
- from functools import wraps
25
25
  from msprobe.core.common.exceptions import DistributedNotInitializedError
26
- from msprobe.core.common.log import logger
27
26
  from msprobe.core.common.file_utils import (FileCheckConst, change_mode,
28
27
  check_file_or_directory_path, check_path_before_create)
29
-
28
+ from msprobe.core.common.log import logger
29
+ from msprobe.core.common.utils import check_seed_all
30
+ from packaging import version
30
31
 
31
32
  try:
32
33
  import torch_npu
@@ -35,10 +36,8 @@ except ImportError:
35
36
  else:
36
37
  is_gpu = False
37
38
 
38
-
39
39
  torch_without_guard_version = torch.__version__ >= '2.1'
40
40
 
41
-
42
41
  if not is_gpu and not torch_without_guard_version:
43
42
  from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard
44
43
 
@@ -46,7 +45,6 @@ npu_distributed_api = ['isend', 'irecv']
46
45
 
47
46
 
48
47
  def parameter_adapter(func):
49
-
50
48
  def handle_masked_select(input_tensor, indices):
51
49
  masked_select_func = getattr(torch._C._VariableFunctionsClass, "masked_select")
52
50
  if input_tensor.dtype == torch.bfloat16:
@@ -80,17 +78,19 @@ def parameter_adapter(func):
80
78
  if self.op_name_ == "__eq__" and args[1] is None:
81
79
  return False
82
80
  return func(self, *args, **kwargs)
81
+
83
82
  return inner
84
83
 
85
84
 
86
85
  def torch_device_guard(func):
87
86
  if is_gpu or torch_without_guard_version:
88
87
  return func
89
- # Parse args/kwargs matched torch.device objects
90
88
 
89
+ # Parse args/kwargs matched torch.device objects
91
90
  @torch_npu_device_guard
92
91
  def wrapper(*args, **kwargs):
93
92
  return func(*args, **kwargs)
93
+
94
94
  return wrapper
95
95
 
96
96
 
@@ -105,20 +105,28 @@ def get_rank_if_initialized():
105
105
 
106
106
 
107
107
  def seed_all(seed=1234, mode=False):
108
- random.seed(seed)
109
- os.environ['PYTHONHASHSEED'] = str(seed)
110
- np.random.seed(seed)
111
- torch.manual_seed(seed)
112
- torch.use_deterministic_algorithms(mode)
113
- if is_gpu:
114
- torch.cuda.manual_seed_all(seed)
115
- torch.cuda.manual_seed(seed)
116
- torch.backends.cudnn.deterministic = True
117
- torch.backends.cudnn.enable = False
118
- torch.backends.cudnn.benchmark = False
119
- else:
120
- torch_npu.npu.manual_seed_all(seed)
121
- torch_npu.npu.manual_seed(seed)
108
+ check_seed_all(seed, mode)
109
+ try:
110
+ random.seed(seed)
111
+ os.environ['PYTHONHASHSEED'] = str(seed)
112
+ np.random.seed(seed)
113
+ torch.manual_seed(seed)
114
+ cuda_version = torch.version.cuda
115
+ if cuda_version is not None and version.parse(cuda_version) >= version.parse("10.2"):
116
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
117
+ os.environ['HCCL_DETERMINISTIC'] = str(mode)
118
+ torch.use_deterministic_algorithms(mode)
119
+ if is_gpu:
120
+ torch.cuda.manual_seed_all(seed)
121
+ torch.cuda.manual_seed(seed)
122
+ torch.backends.cudnn.deterministic = True
123
+ torch.backends.cudnn.enable = False
124
+ torch.backends.cudnn.benchmark = False
125
+ else:
126
+ torch_npu.npu.manual_seed_all(seed)
127
+ torch_npu.npu.manual_seed(seed)
128
+ except Exception as e:
129
+ logger.error(f"There is an unexpected error while determinating randomness. {e}")
122
130
 
123
131
 
124
132
  class Const:
@@ -191,10 +199,7 @@ class Const:
191
199
  ENV_ENABLE = "1"
192
200
  ENV_DISABLE = "0"
193
201
 
194
- MAX_SEED_VALUE = 2**32 - 1
195
-
196
- INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
197
- "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
202
+ MAX_SEED_VALUE = 2 ** 32 - 1
198
203
 
199
204
  TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
200
205
  LEVEL_LIST = ["L0", "L1", "L2", "mix"]
@@ -257,7 +262,7 @@ def print_rank_0(message):
257
262
  logger.info(message)
258
263
  else:
259
264
  logger.info(message)
260
-
265
+
261
266
 
262
267
  def load_pt(pt_path, to_cpu=False):
263
268
  pt_path = os.path.realpath(pt_path)
@@ -279,8 +284,8 @@ def save_pt(tensor, filepath):
279
284
  torch.save(tensor, filepath)
280
285
  except Exception as e:
281
286
  logger.error("Save pt file failed, please check according possible error causes: "
282
- "1. out of disk space or disk error, "
283
- "2. no permission to write files, etc.")
287
+ "1. out of disk space or disk error, "
288
+ "2. no permission to write files, etc.")
284
289
  raise RuntimeError(f"save pt file {filepath} failed") from e
285
290
  change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
286
291
 
@@ -1,8 +1,7 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- # Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved.
5
- # Licensed under the Apache License, Version 2.0 (the "License");
1
+ # Copyright (c) 2019-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
5
  # you may not use this file except in compliance with the License.
7
6
  # You may obtain a copy of the License at
8
7
  #
@@ -13,14 +12,13 @@
13
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
13
  # See the License for the specific language governing permissions and
15
14
  # limitations under the License.
16
- """
15
+
17
16
  import os
18
17
  from msprobe.core.common.utils import CompareException, check_compare_param, \
19
18
  check_configuration_param, task_dumppath_get
20
19
  from msprobe.core.common.file_utils import create_directory
21
20
  from msprobe.core.common.exceptions import FileCheckException
22
21
  from msprobe.pytorch.common.log import logger
23
- from msprobe.core.common.const import Const
24
22
  from msprobe.pytorch.compare.pt_compare import PTComparator
25
23
  from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json
26
24
 
@@ -55,12 +53,14 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs):
55
53
  }
56
54
  try:
57
55
  summary_compare, md5_compare = task_dumppath_get(dump_result_param)
58
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
56
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match,
57
+ dump_result_param.get('is_print_compare_log', True))
59
58
  create_directory(output_path)
60
- check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare)
59
+ check_compare_param(dump_result_param, output_path,
60
+ summary_compare=summary_compare, md5_compare=md5_compare)
61
61
  except (CompareException, FileCheckException) as error:
62
62
  logger.error('Compare failed. Please check the arguments and do it again!')
63
63
  raise CompareException(error.code) from error
64
64
  pt_comparator = PTComparator()
65
- pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare,
66
- md5_compare=md5_compare, **kwargs)
65
+ pt_comparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}',
66
+ summary_compare=summary_compare, md5_compare=md5_compare, **kwargs)
@@ -1,3 +1,18 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os
2
17
  from msprobe.core.common.utils import CompareException
3
18
  from msprobe.core.common.file_utils import load_yaml
@@ -1,17 +1,48 @@
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import os.path
2
17
  import torch
3
18
  from msprobe.core.common.const import FileCheckConst
4
19
  from msprobe.pytorch.common.log import logger
5
20
  from msprobe.core.common.exceptions import FileCheckException
6
21
  from msprobe.core.compare.acc_compare import Comparator
7
- from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, CompareException
8
- from msprobe.core.common.file_utils import FileChecker, create_directory
22
+ from msprobe.core.common.utils import check_configuration_param, task_dumppath_get, check_compare_param, \
23
+ CompareException
24
+ from msprobe.core.common.file_utils import FileChecker, create_directory, load_yaml
9
25
  from msprobe.pytorch.common.utils import load_pt
10
26
 
11
27
 
12
28
  class PTComparator (Comparator):
13
- def __init__(self):
29
+ def __init__(self, data_mapping=None):
14
30
  self.frame_name = PTComparator.__name__
31
+ self.data_mapping = data_mapping
32
+ if isinstance(self.data_mapping, str) or self.data_mapping is None:
33
+ self.data_mapping_dict = self.load_mapping_file(self.data_mapping)
34
+ elif isinstance(self.data_mapping, dict):
35
+ self.data_mapping_dict = self.data_mapping
36
+ else:
37
+ raise TypeError(f"The type of parameter `data_mapping` must be dict, str or None, but got "
38
+ f"{type(self.data_mapping)}")
39
+
40
+ def load_mapping_file(self, mapping_file):
41
+ if isinstance(mapping_file, str):
42
+ mapping_dict = load_yaml(mapping_file)
43
+ else:
44
+ mapping_dict = {}
45
+ return mapping_dict
15
46
 
16
47
  def read_npy_data(self, dir_path, file_name):
17
48
  data_path = os.path.join(dir_path, file_name)
@@ -35,16 +66,17 @@ class PTComparator (Comparator):
35
66
  return data_value
36
67
 
37
68
 
38
- def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False):
69
+ def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False, **kwargs):
39
70
  try:
40
71
  summary_compare, md5_compare = task_dumppath_get(input_param)
41
- check_configuration_param(stack_mode, auto_analyze, fuzzy_match)
72
+ check_configuration_param(stack_mode, auto_analyze, fuzzy_match, input_param.get('is_print_compare_log', True))
42
73
  create_directory(output_path)
43
74
  check_compare_param(input_param, output_path, summary_compare, md5_compare)
75
+ data_mapping = kwargs.get('data_mapping', None)
44
76
  except (CompareException, FileCheckException) as error:
45
77
  logger.error('Compare failed. Please check the arguments and do it again!')
46
78
  raise CompareException(error.code) from error
47
- pt_comparator = PTComparator()
79
+ pt_comparator = PTComparator(data_mapping)
48
80
  pt_comparator.compare_core(input_param, output_path, stack_mode=stack_mode,
49
81
  auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare,
50
82
  md5_compare=md5_compare)
@@ -1,6 +1,23 @@
1
- from msprobe.pytorch.common import seed_all
2
- from msprobe.pytorch.common.log import logger
1
+ # Copyright (c) 2024-2024, Huawei Technologies Co., Ltd.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
3
18
  from msprobe.core.common.const import Const
19
+ from msprobe.core.common.exceptions import MsprobeException
20
+ from msprobe.pytorch.common.log import logger
4
21
 
5
22
 
6
23
  class DebuggerConfig:
@@ -10,8 +27,6 @@ class DebuggerConfig:
10
27
  self.rank = common_config.rank if common_config.rank else []
11
28
  self.step = common_config.step if common_config.step else []
12
29
  self.level = level or common_config.level or "L1"
13
- self.seed = common_config.seed if common_config.seed else 1234
14
- self.is_deterministic = common_config.is_deterministic
15
30
  self.enable_dataloader = common_config.enable_dataloader
16
31
  self.scope = task_config.scope if task_config.scope else []
17
32
  self.list = task_config.list if task_config.list else []
@@ -25,15 +40,15 @@ class DebuggerConfig:
25
40
  self.framework = Const.PT_FRAMEWORK
26
41
 
27
42
  if self.task == Const.FREE_BENCHMARK:
28
- self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu'
29
- self.handler_type = task_config.handler_type if task_config.handler_type else 'check'
30
- self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision'
31
- self.fuzz_level = task_config.fuzz_level if task_config.fuzz_level else 'L1'
32
- self.fuzz_stage = task_config.fuzz_stage if task_config.fuzz_stage else 'forward'
43
+ self.fuzz_device = task_config.fuzz_device
44
+ self.handler_type = task_config.handler_type
45
+ self.pert_mode = task_config.pert_mode
46
+ self.fuzz_level = task_config.fuzz_level
47
+ self.fuzz_stage = task_config.fuzz_stage
33
48
  self.preheat_config = {
34
- "if_preheat": task_config.if_preheat if task_config.if_preheat is not None else True,
35
- "preheat_step": task_config.preheat_step if task_config.preheat_step else 15,
36
- "max_sample": task_config.max_sample if task_config.max_sample else 20,
49
+ "if_preheat": task_config.if_preheat,
50
+ "preheat_step": task_config.preheat_step,
51
+ "max_sample": task_config.max_sample
37
52
  }
38
53
 
39
54
  self.online_run_ut = False
@@ -46,8 +61,7 @@ class DebuggerConfig:
46
61
  self.port = task_config.port if task_config.port else -1
47
62
 
48
63
  self.check()
49
- if self.step:
50
- self.step.sort()
64
+
51
65
  if self.level == "L2":
52
66
  if not self.scope or not isinstance(self.scope, list) or len(self.scope) != 1:
53
67
  raise ValueError("scope must be configured as a list with one api name")
@@ -58,38 +72,37 @@ class DebuggerConfig:
58
72
  for index, scope_spec in enumerate(self.scope):
59
73
  self.scope[index] = scope_spec.replace(Const.BACKWARD, Const.FORWARD)
60
74
  self.backward_input[self.scope[index]] = self.backward_input_list[index]
61
- seed_all(self.seed, self.is_deterministic)
62
75
 
63
76
  def check_kwargs(self):
64
77
  if self.task and self.task not in Const.TASK_LIST:
65
- raise Exception("task is invalid")
78
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
79
+ f"The task <{self.task}> is not in the {Const.TASK_LIST}.")
66
80
  if self.level and self.level not in Const.LEVEL_LIST:
67
- raise Exception("level is invalid")
81
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
82
+ f"The level <{self.level}> is not in the {Const.LEVEL_LIST}.")
68
83
  if not self.dump_path:
69
- raise Exception("Invalid dump path, please check your config")
84
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR,
85
+ f"The dump_path not found.")
70
86
 
71
87
  def check(self):
72
88
  self.check_kwargs()
73
- self._check_rank()
74
- self._check_step()
75
89
  return True
76
90
 
77
- def check_model(self, model):
78
- if self.level in ["L0", "mix"] and not model:
79
- raise Exception(
80
- f"For level {self.level}, PrecisionDebugger must receive a model argument."
81
- )
82
-
83
- def _check_rank(self):
84
- if self.rank:
85
- for rank_id in self.rank:
86
- if not isinstance(rank_id, int) or rank_id < 0:
87
- raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.")
88
- else:
89
- logger.warning_on_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.")
90
-
91
- def _check_step(self):
92
- if self.step:
93
- for s in self.step:
94
- if not isinstance(s, int) or s < 0:
95
- raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
91
+ def check_model(self, instance, start_model):
92
+ if self.level not in ["L0", "mix"]:
93
+ if instance.model is not None or start_model is not None:
94
+ logger.warning_on_rank_0(
95
+ f"The current level is not L0 or mix level, so the model parameters will not be used.")
96
+ return
97
+ if start_model is None:
98
+ if instance.model is None:
99
+ logger.error_on_rank_0(
100
+ f"For level {self.level}, PrecisionDebugger or start interface must receive a 'model' argument.")
101
+ raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, f"missing the parameter 'model'")
102
+ return
103
+ if isinstance(start_model, torch.nn.Module):
104
+ instance.model = start_model
105
+ else:
106
+ logger.error_on_rank_0(f"The 'model' parameter of start must be a torch.nn.Module type.")
107
+ raise MsprobeException(
108
+ MsprobeException.INVALID_PARAM_ERROR, f"model must be a torch.nn.Module")