mindstudio-probe: mindstudio_probe-8.1.1-py3-none-any.whl → mindstudio_probe-8.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
- msprobe/core/common/const.py +3 -0
- msprobe/core/common/file_utils.py +45 -5
- msprobe/core/common/utils.py +117 -13
- msprobe/core/common_config.py +15 -1
- msprobe/core/compare/acc_compare.py +21 -9
- msprobe/core/compare/compare_cli.py +10 -2
- msprobe/core/compare/merge_result/merge_result.py +1 -1
- msprobe/core/compare/utils.py +8 -2
- msprobe/core/config_check/checkers/base_checker.py +2 -0
- msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
- msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
- msprobe/core/config_check/config_check_cli.py +1 -1
- msprobe/core/config_check/config_checker.py +1 -2
- msprobe/core/data_dump/data_collector.py +4 -1
- msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
- msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
- msprobe/core/debugger/precision_debugger.py +13 -8
- msprobe/core/hook_manager.py +112 -82
- msprobe/core/monitor/utils.py +338 -0
- msprobe/core/service.py +2 -1
- msprobe/core/single_save/single_comparator.py +5 -3
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +4 -4
- msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
- msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
- msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
- msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
- msprobe/docs/12.overflow_check_PyTorch.md +3 -2
- msprobe/docs/13.overflow_check_MindSpore.md +1 -1
- msprobe/docs/14.data_parse_PyTorch.md +35 -32
- msprobe/docs/21.visualization_PyTorch.md +9 -8
- msprobe/docs/22.visualization_MindSpore.md +1 -0
- msprobe/docs/23.generate_operator_PyTorch.md +1 -1
- msprobe/docs/24.code_mapping_Mindspore.md +6 -5
- msprobe/docs/31.config_check.md +15 -5
- msprobe/docs/33.generate_operator_MindSpore.md +2 -2
- msprobe/docs/34.RL_collect.md +18 -9
- msprobe/docs/35.nan_analyze.md +4 -3
- msprobe/docs/FAQ.md +3 -0
- msprobe/docs/img/ms_layer.png +0 -0
- msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
- msprobe/mindspore/cell_processor.py +35 -14
- msprobe/mindspore/code_mapping/bind.py +23 -4
- msprobe/mindspore/code_mapping/graph_parser.py +6 -4
- msprobe/mindspore/common/utils.py +3 -0
- msprobe/mindspore/compare/common_dir_compare.py +32 -12
- msprobe/mindspore/compare/ms_graph_compare.py +7 -2
- msprobe/mindspore/compare/utils.py +9 -1
- msprobe/mindspore/debugger/debugger_config.py +13 -11
- msprobe/mindspore/debugger/precision_debugger.py +67 -45
- msprobe/mindspore/dump/dump_tool_factory.py +2 -0
- msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
- msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
- msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
- msprobe/mindspore/dump/jit_dump.py +6 -3
- msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
- msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
- msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
- msprobe/mindspore/mindspore_service.py +2 -2
- msprobe/mindspore/monitor/common_func.py +1 -1
- msprobe/mindspore/monitor/module_hook.py +3 -3
- msprobe/mindspore/monitor/utils.py +0 -252
- msprobe/mindspore/ms_config.py +0 -1
- msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
- msprobe/nan_analyze/graph.py +4 -0
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
- msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
- msprobe/pytorch/common/utils.py +0 -16
- msprobe/pytorch/compare/pt_compare.py +5 -0
- msprobe/pytorch/debugger/debugger_config.py +12 -5
- msprobe/pytorch/debugger/precision_debugger.py +8 -1
- msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
- msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
- msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
- msprobe/pytorch/hook_module/hook_module.py +9 -9
- msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
- msprobe/pytorch/monitor/csv2tb.py +3 -10
- msprobe/pytorch/monitor/features.py +5 -0
- msprobe/pytorch/monitor/module_hook.py +6 -7
- msprobe/pytorch/monitor/module_metric.py +0 -3
- msprobe/pytorch/monitor/optimizer_collect.py +1 -1
- msprobe/pytorch/monitor/utils.py +1 -317
- msprobe/pytorch/online_dispatch/dispatch.py +1 -1
- msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
- msprobe/pytorch/parse_tool/lib/utils.py +2 -4
- msprobe/visualization/graph_service.py +1 -1
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/dump/module_dump/module_processer.py:

@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
+import sys
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.utils import ModuleQueue, ThreadSafe
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
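`ModuleQueue` and `ThreadSafe` are new helpers exported by `msprobe/core/common/utils.py` (whose +117 lines are not expanded in this diff). Judging only from the call sites below — `ThreadSafe.synchronized` as a hook decorator, `ThreadSafe.acquire()` in `HOOKModule.__init__`, and `ModuleQueue.add_name`/`remove_name`/`find_last` for cross-thread parent lookup — a minimal sketch could look like this; every name and behavior beyond those call sites is an assumption, not the shipped implementation:

```python
import threading
from functools import wraps


class ThreadSafe:
    """Sketch: one shared re-entrant lock guarding hook bookkeeping."""
    _lock = threading.RLock()

    @classmethod
    def acquire(cls):
        cls._lock.acquire()

    @classmethod
    def release(cls):
        cls._lock.release()

    @classmethod
    def synchronized(cls, func):
        # Decorator form used on the forward/backward hooks below.
        @wraps(func)
        def wrapper(*args, **kwargs):
            with cls._lock:
                return func(*args, **kwargs)
        return wrapper


class ModuleQueue:
    """Sketch: ordered record of module names currently in flight, used to
    guess a parent when the current thread's own stack is empty."""

    def __init__(self):
        self._names = []
        self._lock = threading.Lock()

    def add_name(self, name):
        with self._lock:
            self._names.append(name)

    def remove_name(self, name):
        with self._lock:
            if name in self._names:
                self._names.remove(name)

    def find_last(self, name):
        # Most recently added name other than `name`, or None.
        with self._lock:
            for candidate in reversed(self._names):
                if candidate != name:
                    return candidate
        return None
```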
@@ -46,13 +49,15 @@ def wrap_megatron_deallocate(func):
             out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
             return func(out_clone, deallocate_pipeline_outputs)
         return func(out, deallocate_pipeline_outputs)
+
     return wrapper_func
 
 
 class ModuleProcesser:
+    module_queue = ModuleQueue()
     module_count = {}
-    module_stack =
-    api_parent_node =
+    module_stack = {}
+    api_parent_node = {}
     module_node = {}
     module_bw_hook_kernels = {}
     module_with_backward_hook = {}
@@ -64,7 +69,15 @@ class ModuleProcesser:
         replace_checkpoint()
         try:
             from megatron.core.pipeline_parallel import schedules
+            origin_func_id = id(schedules.deallocate_output_tensor)
             schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            for module in list(sys.modules.values()):
+                if module.__name__ == 'schedules':
+                    continue
+                for func in module.__dict__:
+                    if id(module.__dict__[func]) == origin_func_id:
+                        module.__setattr__(func, schedules.deallocate_output_tensor)
+                        logger.debug(f'patch {module.__name__}.{func}.')
             logger.info_on_rank_0("Patch megatron method success.")
         except ImportError:
             logger.info_on_rank_0("No megatron find.")
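Capturing `id(schedules.deallocate_output_tensor)` before wrapping lets the new loop also rebind modules that imported the function by value (`from ... import deallocate_output_tensor`); reassigning the attribute on `schedules` alone would leave such aliases calling the unwrapped original. A self-contained illustration of that pitfall, using hypothetical module names:

```python
import sys
import types

# Hypothetical stand-ins for megatron's `schedules` and a module that
# imported the function by value rather than via the module attribute.
schedules = types.ModuleType('schedules')
schedules.deallocate_output_tensor = lambda out: 'original'

consumer = types.ModuleType('consumer')
consumer.deallocate_output_tensor = schedules.deallocate_output_tensor
sys.modules['consumer'] = consumer

origin_func_id = id(schedules.deallocate_output_tensor)
schedules.deallocate_output_tensor = lambda out: 'wrapped'

# The by-value alias still points at the original function...
assert consumer.deallocate_output_tensor(None) == 'original'

# ...so the patch sweeps sys.modules and rebinds anything with the old id.
for module in list(sys.modules.values()):
    if module is None:
        continue
    for name, value in list(getattr(module, '__dict__', {}).items()):
        if id(value) == origin_func_id:
            setattr(module, name, schedules.deallocate_output_tensor)

assert consumer.deallocate_output_tensor(None) == 'wrapped'
sys.modules.pop('consumer')  # clean up the demo entry
```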
@@ -103,9 +116,10 @@ class ModuleProcesser:
 
     @classmethod
     def reset_module_stats(cls):
+        cls.module_queue = ModuleQueue()
         cls.module_count = {}
-        cls.module_stack =
-        cls.api_parent_node =
+        cls.module_stack = {}
+        cls.api_parent_node = {}
         cls.module_node = {}
         cls.module_bw_hook_kernels = {}
         cls.enable_module_dump = False
@@ -144,6 +158,7 @@ class ModuleProcesser:
         register_forward_pre_hook(module, forward_pre_hook)
 
     def build_module_hook(self, module_name, build_data_hook):
+        @ThreadSafe.synchronized
         def forward_pre_hook(module, args, kwargs=None):
             if kwargs is None:
                 kwargs = {}
@@ -171,15 +186,19 @@ class ModuleProcesser:
         hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
 
         def get_backward_pre_hook(full_backward_name):
+            @ThreadSafe.synchronized
             def backward_pre_hook_fn(module, grad_output):
                 self.set_construct_info_in_pre_hook(full_backward_name)
+
             return backward_pre_hook_fn
 
         def get_backward_hook(backward_data_hook, full_backward_name):
+            @ThreadSafe.synchronized
             def backward_hook_fn(module, grad_input, grad_output):
                 new_output = backward_data_hook(module, grad_input, grad_output)
                 self.set_construct_info_in_hook(full_backward_name, is_forward=False)
                 return new_output
+
             return backward_hook_fn
 
         if not ModuleProcesser.module_with_backward_hook.get(module_name):
@@ -193,6 +212,7 @@ class ModuleProcesser:
                 args = bw_hook.setup_input_hook(args)
                 return (args, kwargs) if torch_version_above_or_equal_2 else args
 
+        @ThreadSafe.synchronized
        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
             if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
                 return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
@@ -218,23 +238,34 @@ class ModuleProcesser:
         return forward_pre_hook
 
     def set_construct_info_in_pre_hook(self, full_name):
-
-
+        tid = threading.get_ident()
+        if tid not in self.module_stack:
+            ModuleProcesser.module_stack[tid] = []
+
+        if self.module_stack[tid]:
+            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
         else:
-            ModuleProcesser.
-
-
+            parent_name = ModuleProcesser.module_queue.find_last(full_name)
+            ModuleProcesser.module_node[full_name] = parent_name
+
+        ModuleProcesser.module_queue.add_name(full_name)
+        ModuleProcesser.module_stack[tid].append(full_name)
+        ModuleProcesser.api_parent_node[tid] = full_name
         if self.scope:
             self.scope.begin_module(full_name)
 
     def set_construct_info_in_hook(self, full_name, is_forward=True):
+        tid = threading.get_ident()
         if torch_version_above_or_equal_2 or is_forward:
-
-
-
+            ModuleProcesser.module_queue.remove_name(full_name)
+            ModuleProcesser.api_parent_node[tid] = None
+            if self.module_stack.get(tid):
+                ModuleProcesser.module_stack[tid].pop()
+            if self.module_stack.get(tid):
+                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
             if self.scope:
                 self.scope.end_module(full_name)
         else:
             if self.scope:
                 self.scope.begin_module(full_name)
-            ModuleProcesser.api_parent_node = full_name
+            ModuleProcesser.api_parent_node[tid] = full_name
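The bookkeeping above keys `module_stack` and `api_parent_node` by `threading.get_ident()`, so concurrent forward/backward passes no longer pop entries that belong to another thread, and `module_queue` supplies a parent guess when a hook fires on a thread whose own stack is empty. The per-thread-stack pattern in isolation (a minimal sketch, independent of msprobe):

```python
import threading

module_stack = {}  # tid -> names of modules currently open on that thread


def push(full_name):
    tid = threading.get_ident()
    module_stack.setdefault(tid, []).append(full_name)


def pop():
    tid = threading.get_ident()
    stack = module_stack.get(tid)
    if stack:
        stack.pop()
    return stack[-1] if stack else None  # parent after unwinding


def worker(prefix):
    push(f"{prefix}.layer0")             # nested module calls...
    push(f"{prefix}.layer0.linear")
    pop()                                # ...unwind in LIFO order
    pop()


threads = [threading.Thread(target=worker, args=(f"rank{i}",)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert all(not stack for stack in module_stack.values())  # every stack unwound
```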
msprobe/pytorch/free_benchmark/result_handlers/base_handler.py:

@@ -186,6 +186,8 @@ class FuzzHandler(ABC):
         ratio = self.ratio_calculate(
             origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
         )
+        if threshold == 0:
+            raise ValueError("Threshold cannot be zero. Check `get_threshold` implementation.")
         if ratio == ThresholdConfig.SYMBOL_FLIPPING:
             is_consistent = False
         else:
msprobe/pytorch/hook_module/hook_module.py:

@@ -22,20 +22,19 @@ import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
 from msprobe.core.common.runtime import Runtime
-from msprobe.
+from msprobe.core.common.utils import ThreadSafe
+from msprobe.pytorch.common.utils import register_forward_pre_hook, register_forward_hook
 
 
 class HOOKModule(nn.Module):
     module_count = defaultdict(int)
-    inner_stop_hook =
+    inner_stop_hook = defaultdict(bool)
 
     def __init__(self, hook_build_func) -> None:
         super(HOOKModule, self).__init__()
         self.has_overflow = False
-        self.
-
-        HOOKModule.inner_stop_hook[self.current_thread] = False
-        self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
+        self.tid = threading.get_ident()
+        self.stop_hook = HOOKModule.inner_stop_hook.get(self.tid, False)
 
         if not self.stop_hook:
             self.forward_data_collected = False
@@ -43,6 +42,7 @@ class HOOKModule(nn.Module):
         if not Runtime.is_running:
             return
         prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+        ThreadSafe.acquire()
         if callable(hook_build_func):
             hook_set = hook_build_func(prefix)
             register_forward_pre_hook(self, hook_set.forward_pre_hook)
@@ -52,11 +52,11 @@ class HOOKModule(nn.Module):
     def __call__(self, *args, **kwargs):
         changed = False
         if not self.stop_hook:
-            HOOKModule.inner_stop_hook[self.
+            HOOKModule.inner_stop_hook[self.tid] = True
             changed = True
         result = self._call_func(*args, **kwargs)
         if changed:
-            HOOKModule.inner_stop_hook[self.
+            HOOKModule.inner_stop_hook[self.tid] = False
         return result
 
     @staticmethod
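With `inner_stop_hook` now a `defaultdict(bool)` keyed by thread id, `__init__` no longer needs to pre-seed the current thread's entry: the outermost hooked call on each thread raises the flag, nested calls on that same thread skip collection, and other threads are unaffected. The guard reduced to a sketch:

```python
import threading
from collections import defaultdict

inner_stop_hook = defaultdict(bool)  # tid -> "already inside a hooked call"


def hooked_call(depth=0):
    tid = threading.get_ident()
    changed = False
    if not inner_stop_hook[tid]:
        inner_stop_hook[tid] = True   # outermost call on this thread
        changed = True
    collected = changed               # only the outermost call collects data
    if depth < 2:
        hooked_call(depth + 1)        # nested calls see the flag and skip
    if changed:
        inner_stop_hook[tid] = False  # re-arm for the next top-level call
    return collected


assert hooked_call() is True
```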
@@ -104,7 +104,7 @@ class HOOKModule(nn.Module):
         else:
             return result
 
-        if
+        if not (var.requires_grad and torch.is_grad_enabled()):
             return result
 
         grad_fn = var.grad_fn
msprobe/pytorch/hook_module/pt_hook_manager.py:

@@ -23,7 +23,7 @@ from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 
 
-class PytorchHookManager(BaseHookManager):
+class PytorchHookManager(BaseHookManager):
     @property
     def _is_recompute(self):
         return is_recomputation()
@@ -41,7 +41,7 @@ class PytorchHookManager(BaseHookManager):
         kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
         output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
         return kwargs, output
-
+
     def build_hook(self, hook_type, name):
         if hook_type == Const.API:
             full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
@@ -51,10 +51,10 @@ class PytorchHookManager(BaseHookManager):
             hookset = HookSet(
                 forward_hook=self._build_forward_hook(hook_type, full_forward_name),
                 forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
-                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
             )
             return hookset
-
+
     def _need_exchange(self, module):
         return True
 
@@ -62,7 +62,7 @@ class PytorchHookManager(BaseHookManager):
         params_dict = {}
         if self.config.task != Const.STRUCTURE:
             params_dict = {
-
-
-
+                key.split(Const.SEP)[-1]: value
+                for key, value in module.named_parameters(recurse=False)
+            }
         return params_dict
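`_get_params_dict` collects only the module's own parameters (`recurse=False`) and keeps just the last component of each key. Assuming `Const.SEP` is the usual `"."` separator, the comprehension behaves like this:

```python
import torch.nn as nn

module = nn.Linear(4, 2)
params_dict = {
    key.split(".")[-1]: value  # assumes Const.SEP == "."
    for key, value in module.named_parameters(recurse=False)
}
print(sorted(params_dict))  # ['bias', 'weight']
```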
msprobe/pytorch/monitor/csv2tb.py:

@@ -23,17 +23,17 @@ from tqdm import tqdm
 
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod
-from msprobe.core.common.utils import
+from msprobe.core.common.utils import check_process_num
 from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.monitor.utils import get_target_output_dir
 from msprobe.pytorch.common.log import logger
-
+
 
 all_data_type_list = [
     "actv", "actv_grad", "exp_avg", "exp_avg_sq",
     "grad_unreduced", "grad_reduced", "param_origin", "param_updated"
 ]
 CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
-MAX_PROCESS_NUM = 128
 
 
 def parse_step_line(line, ops):
@@ -119,13 +119,6 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
             write_step(output_dirpath, all_step_result, rank, data_type)
 
 
-def check_process_num(process_num):
-    if not is_int(process_num) or process_num <= 0:
-        raise ValueError(f"process_num({process_num}) is not a positive integer")
-    if process_num > MAX_PROCESS_NUM:
-        raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.")
-
-
 def check_data_type_list(data_type_list):
     if data_type_list is None:
         logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
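`check_process_num` is not gone: the import hunk above shows it now comes from `msprobe.core.common.utils` (part of that module's +117 lines in the summary). Assuming it moved over unchanged, callers keep the same validation behavior:

```python
from msprobe.core.common.utils import check_process_num

check_process_num(8)  # a positive integer within the limit passes silently
try:
    check_process_num(0)
except ValueError as err:
    print(err)  # e.g. "process_num(0) is not a positive integer"
```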
msprobe/pytorch/monitor/features.py:

@@ -45,13 +45,18 @@ def get_max(x: torch.tensor):
 
 @torch.no_grad()
 def get_zeros(x: torch.tensor, eps: float):
+    if x.numel() == 0:
+        return torch.tensor(float('nan'))
     return torch.sum(torch.abs(x) < eps) / x.numel()
 
 
 @torch.no_grad()
 def get_sign_matches(x: torch.tensor, y: torch.tensor):
+    if y.numel() == 0:
+        return torch.tensor(1.)
     xs = x.sign()
     ys = y.sign()
+
     try:
         same_direction_ratio = ((xs * ys).sum() / ys.numel() + 1) / 2
     except RuntimeError as e:
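Both guards cover empty tensors, where `x.numel()` is zero and the old expressions reduced to a 0/0: `get_zeros` now reports NaN explicitly, and `get_sign_matches` treats an empty comparison as a full match. For example:

```python
import torch


@torch.no_grad()
def get_zeros(x: torch.Tensor, eps: float):
    if x.numel() == 0:
        return torch.tensor(float('nan'))
    return torch.sum(torch.abs(x) < eps) / x.numel()


print(get_zeros(torch.empty(0), 1e-8))  # tensor(nan): explicit, no 0/0
print(get_zeros(torch.zeros(4), 1e-8))  # tensor(1.): every element is a "zero"
print(get_zeros(torch.ones(4), 1e-8))   # tensor(0.)
```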
msprobe/pytorch/monitor/module_hook.py:

@@ -31,8 +31,11 @@ from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
 from msprobe.core.common.file_utils import write_df_to_csv
 from msprobe.core.common.utils import analyze_api_call_stack
+from msprobe.core.monitor.utils import validate_config, validate_ops, \
+    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_recomputation
+from msprobe.pytorch.common.utils import is_recomputation
+from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
     get_process_group
@@ -40,8 +43,6 @@ from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
     TensorMetrics, squash_param_name
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
-from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
-    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
 
@@ -592,7 +593,7 @@ class TrainerMon:
             context.param_adam_update = mv_result.update
             context.param_adam_ratio = mv_result.ratio
 
-        self.generate_wgrad_metrics(grad_dict)
+        _, _ = self.generate_wgrad_metrics(grad_dict)
         self.generate_mv_metrics(context)
         self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
 
@@ -763,7 +764,7 @@ class TrainerMon:
         def clone_if_tensor(args):
             if isinstance(args, tuple):
                 return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor)
+            elif isinstance(args, torch.Tensor):
                 return args.clone()
             else:
                 return args
@@ -1170,8 +1171,6 @@ class TrainerMon:
                 grad = param.main_grad
             else:
                 grad = param.grad
-            if is_float8_tensor(grad):
-                grad = grad.float()
             context_dict[key] = grad.clone()
 
         if param.micro_step == self.micro_batch_number:
msprobe/pytorch/monitor/module_metric.py:

@@ -16,7 +16,6 @@ import re
 
 import torch
 
-from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -181,8 +180,6 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
             # Non-tensor in/output filled with nan.
             out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
             continue
-        if is_float8_tensor(tensor):
-            tensor = tensor.float()
         for metric_name in ops:
             fun_metric = config_metric_registry.get(metric_name)
             out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
msprobe/pytorch/monitor/optimizer_collect.py:

@@ -17,7 +17,7 @@ from abc import abstractmethod
 import torch
 
 from msprobe.pytorch.common.log import logger
-from msprobe.
+from msprobe.core.monitor.utils import MVResult
 from msprobe.core.common.const import MonitorConst
 
 
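`MVResult` moves with the other shared monitor helpers into `msprobe/core/monitor/utils.py` (the new +338-line module in the summary). Only two of its fields are visible in this diff (`mv_result.update` and `mv_result.ratio` in `module_hook.py`), so the shape below is purely illustrative, not the shipped definition:

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class MVResult:
    """Illustrative only: `update` and `ratio` are the fields this diff uses;
    the exp_avg/exp_avg_sq fields are guesses from the optimizer context."""
    exp_avg: Dict[str, Any] = field(default_factory=dict)
    exp_avg_sq: Dict[str, Any] = field(default_factory=dict)
    update: Dict[str, Any] = field(default_factory=dict)
    ratio: Dict[str, Any] = field(default_factory=dict)
```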