mindstudio-probe: mindstudio_probe-8.1.1-py3-none-any.whl → mindstudio_probe-8.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/METADATA +1 -1
  2. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/RECORD +95 -94
  3. msprobe/core/common/const.py +3 -0
  4. msprobe/core/common/file_utils.py +45 -5
  5. msprobe/core/common/utils.py +117 -13
  6. msprobe/core/common_config.py +15 -1
  7. msprobe/core/compare/acc_compare.py +21 -9
  8. msprobe/core/compare/compare_cli.py +10 -2
  9. msprobe/core/compare/merge_result/merge_result.py +1 -1
  10. msprobe/core/compare/utils.py +8 -2
  11. msprobe/core/config_check/checkers/base_checker.py +2 -0
  12. msprobe/core/config_check/checkers/hyperparameter_checker.py +5 -4
  13. msprobe/core/config_check/ckpt_compare/ckpt_comparator.py +4 -1
  14. msprobe/core/config_check/config_check_cli.py +1 -1
  15. msprobe/core/config_check/config_checker.py +1 -2
  16. msprobe/core/data_dump/data_collector.py +4 -1
  17. msprobe/core/data_dump/data_processor/mindspore_processor.py +23 -1
  18. msprobe/core/data_dump/data_processor/pytorch_processor.py +3 -25
  19. msprobe/core/debugger/precision_debugger.py +13 -8
  20. msprobe/core/hook_manager.py +112 -82
  21. msprobe/core/monitor/utils.py +338 -0
  22. msprobe/core/service.py +2 -1
  23. msprobe/core/single_save/single_comparator.py +5 -3
  24. msprobe/docs/01.installation.md +1 -0
  25. msprobe/docs/05.data_dump_PyTorch.md +4 -4
  26. msprobe/docs/07.accuracy_checker_PyTorch.md +14 -11
  27. msprobe/docs/09.accuracy_checker_MindSpore.md +13 -11
  28. msprobe/docs/10.accuracy_compare_PyTorch.md +3 -1
  29. msprobe/docs/11.accuracy_compare_MindSpore.md +4 -2
  30. msprobe/docs/12.overflow_check_PyTorch.md +3 -2
  31. msprobe/docs/13.overflow_check_MindSpore.md +1 -1
  32. msprobe/docs/14.data_parse_PyTorch.md +35 -32
  33. msprobe/docs/21.visualization_PyTorch.md +9 -8
  34. msprobe/docs/22.visualization_MindSpore.md +1 -0
  35. msprobe/docs/23.generate_operator_PyTorch.md +1 -1
  36. msprobe/docs/24.code_mapping_Mindspore.md +6 -5
  37. msprobe/docs/31.config_check.md +15 -5
  38. msprobe/docs/33.generate_operator_MindSpore.md +2 -2
  39. msprobe/docs/34.RL_collect.md +18 -9
  40. msprobe/docs/35.nan_analyze.md +4 -3
  41. msprobe/docs/FAQ.md +3 -0
  42. msprobe/docs/img/ms_layer.png +0 -0
  43. msprobe/mindspore/api_accuracy_checker/api_runner.py +29 -1
  44. msprobe/mindspore/cell_processor.py +35 -14
  45. msprobe/mindspore/code_mapping/bind.py +23 -4
  46. msprobe/mindspore/code_mapping/graph_parser.py +6 -4
  47. msprobe/mindspore/common/utils.py +3 -0
  48. msprobe/mindspore/compare/common_dir_compare.py +32 -12
  49. msprobe/mindspore/compare/ms_graph_compare.py +7 -2
  50. msprobe/mindspore/compare/utils.py +9 -1
  51. msprobe/mindspore/debugger/debugger_config.py +13 -11
  52. msprobe/mindspore/debugger/precision_debugger.py +67 -45
  53. msprobe/mindspore/dump/dump_tool_factory.py +2 -0
  54. msprobe/mindspore/dump/hook_cell/hook_cell.py +14 -9
  55. msprobe/mindspore/dump/hook_cell/ms_hook_manager.py +12 -7
  56. msprobe/mindspore/dump/hook_cell/primitive_hooks.py +27 -13
  57. msprobe/mindspore/dump/jit_dump.py +6 -3
  58. msprobe/mindspore/dump/kernel_kbyk_dump.py +13 -6
  59. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +6 -5
  60. msprobe/mindspore/dym_loader/hook_dynamic_loader.h +2 -2
  61. msprobe/mindspore/grad_probe/grad_analyzer.py +2 -0
  62. msprobe/mindspore/mindspore_service.py +2 -2
  63. msprobe/mindspore/monitor/common_func.py +1 -1
  64. msprobe/mindspore/monitor/module_hook.py +3 -3
  65. msprobe/mindspore/monitor/utils.py +0 -252
  66. msprobe/mindspore/ms_config.py +0 -1
  67. msprobe/mindspore/overflow_check/overflow_check_tool_factory.py +1 -1
  68. msprobe/nan_analyze/graph.py +4 -0
  69. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +15 -6
  70. msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +1 -1
  71. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py +1 -1
  72. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -4
  73. msprobe/pytorch/common/utils.py +0 -16
  74. msprobe/pytorch/compare/pt_compare.py +5 -0
  75. msprobe/pytorch/debugger/debugger_config.py +12 -5
  76. msprobe/pytorch/debugger/precision_debugger.py +8 -1
  77. msprobe/pytorch/dump/module_dump/hook_wrapper.py +1 -3
  78. msprobe/pytorch/dump/module_dump/module_processer.py +44 -13
  79. msprobe/pytorch/free_benchmark/result_handlers/base_handler.py +2 -0
  80. msprobe/pytorch/hook_module/hook_module.py +9 -9
  81. msprobe/pytorch/hook_module/pt_hook_manager.py +7 -7
  82. msprobe/pytorch/monitor/csv2tb.py +3 -10
  83. msprobe/pytorch/monitor/features.py +5 -0
  84. msprobe/pytorch/monitor/module_hook.py +6 -7
  85. msprobe/pytorch/monitor/module_metric.py +0 -3
  86. msprobe/pytorch/monitor/optimizer_collect.py +1 -1
  87. msprobe/pytorch/monitor/utils.py +1 -317
  88. msprobe/pytorch/online_dispatch/dispatch.py +1 -1
  89. msprobe/pytorch/online_dispatch/dump_compare.py +7 -1
  90. msprobe/pytorch/parse_tool/lib/utils.py +2 -4
  91. msprobe/visualization/graph_service.py +1 -1
  92. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/LICENSE +0 -0
  93. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/WHEEL +0 -0
  94. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/entry_points.txt +0 -0
  95. {mindstudio_probe-8.1.1.dist-info → mindstudio_probe-8.1.2.dist-info}/top_level.txt +0 -0
msprobe/pytorch/dump/module_dump/module_processer.py

@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
+import sys
 from collections import OrderedDict
 
 import torch
 from torch.utils.hooks import BackwardHook, RemovableHandle
 
 from msprobe.core.common.const import Const
+from msprobe.core.common.utils import ModuleQueue, ThreadSafe
 from msprobe.core.data_dump.scope import BaseScope, ModuleRangeScope, MixRangeScope
 from msprobe.pytorch.common.log import logger
 from msprobe.pytorch.common.utils import is_torch_nn_module, register_forward_pre_hook
@@ -46,13 +49,15 @@ def wrap_megatron_deallocate(func):
             out.data = torch.empty((1,), device=out.device, dtype=out.dtype, )
             return func(out_clone, deallocate_pipeline_outputs)
         return func(out, deallocate_pipeline_outputs)
+
     return wrapper_func
 
 
 class ModuleProcesser:
+    module_queue = ModuleQueue()
     module_count = {}
-    module_stack = []
-    api_parent_node = ""
+    module_stack = {}
+    api_parent_node = {}
     module_node = {}
     module_bw_hook_kernels = {}
     module_with_backward_hook = {}
@@ -64,7 +69,15 @@ class ModuleProcesser:
         replace_checkpoint()
         try:
             from megatron.core.pipeline_parallel import schedules
+            origin_func_id = id(schedules.deallocate_output_tensor)
             schedules.deallocate_output_tensor = wrap_megatron_deallocate(schedules.deallocate_output_tensor)
+            for module in list(sys.modules.values()):
+                if module.__name__ == 'schedules':
+                    continue
+                for func in module.__dict__:
+                    if id(module.__dict__[func]) == origin_func_id:
+                        module.__setattr__(func, schedules.deallocate_output_tensor)
+                        logger.debug(f'patch {module.__name__}.{func}.')
             logger.info_on_rank_0("Patch megatron method success.")
         except ImportError:
             logger.info_on_rank_0("No megatron find.")
@@ -103,9 +116,10 @@ class ModuleProcesser:
 
     @classmethod
     def reset_module_stats(cls):
+        cls.module_queue = ModuleQueue()
         cls.module_count = {}
-        cls.module_stack = []
-        cls.api_parent_node = ""
+        cls.module_stack = {}
+        cls.api_parent_node = {}
         cls.module_node = {}
         cls.module_bw_hook_kernels = {}
         cls.enable_module_dump = False
@@ -144,6 +158,7 @@ class ModuleProcesser:
         register_forward_pre_hook(module, forward_pre_hook)
 
     def build_module_hook(self, module_name, build_data_hook):
+        @ThreadSafe.synchronized
        def forward_pre_hook(module, args, kwargs=None):
            if kwargs is None:
                kwargs = {}
@@ -171,15 +186,19 @@
            hook_set = build_data_hook(BaseScope.Module_Type_Module, full_forward_name)
 
            def get_backward_pre_hook(full_backward_name):
+                @ThreadSafe.synchronized
                def backward_pre_hook_fn(module, grad_output):
                    self.set_construct_info_in_pre_hook(full_backward_name)
+
                return backward_pre_hook_fn
 
            def get_backward_hook(backward_data_hook, full_backward_name):
+                @ThreadSafe.synchronized
                def backward_hook_fn(module, grad_input, grad_output):
                    new_output = backward_data_hook(module, grad_input, grad_output)
                    self.set_construct_info_in_hook(full_backward_name, is_forward=False)
                    return new_output
+
                return backward_hook_fn
 
            if not ModuleProcesser.module_with_backward_hook.get(module_name):
@@ -193,6 +212,7 @@
                args = bw_hook.setup_input_hook(args)
            return (args, kwargs) if torch_version_above_or_equal_2 else args
 
+        @ThreadSafe.synchronized
        def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None):
            if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump:
                return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
@@ -218,23 +238,34 @@
        return forward_pre_hook
 
    def set_construct_info_in_pre_hook(self, full_name):
-        if self.module_stack:
-            ModuleProcesser.module_node[full_name] = self.module_stack[-1]
+        tid = threading.get_ident()
+        if tid not in self.module_stack:
+            ModuleProcesser.module_stack[tid] = []
+
+        if self.module_stack[tid]:
+            ModuleProcesser.module_node[full_name] = self.module_stack[tid][-1]
        else:
-            ModuleProcesser.module_node[full_name] = None
-        ModuleProcesser.module_stack.append(full_name)
-        ModuleProcesser.api_parent_node = full_name
+            parent_name = ModuleProcesser.module_queue.find_last(full_name)
+            ModuleProcesser.module_node[full_name] = parent_name
+
+        ModuleProcesser.module_queue.add_name(full_name)
+        ModuleProcesser.module_stack[tid].append(full_name)
+        ModuleProcesser.api_parent_node[tid] = full_name
        if self.scope:
            self.scope.begin_module(full_name)
 
    def set_construct_info_in_hook(self, full_name, is_forward=True):
+        tid = threading.get_ident()
        if torch_version_above_or_equal_2 or is_forward:
-            if self.module_stack:
-                ModuleProcesser.module_stack.pop()
-            ModuleProcesser.api_parent_node = ModuleProcesser.module_stack[-1] if self.module_stack else None
+            ModuleProcesser.module_queue.remove_name(full_name)
+            ModuleProcesser.api_parent_node[tid] = None
+            if self.module_stack.get(tid):
+                ModuleProcesser.module_stack[tid].pop()
+            if self.module_stack.get(tid):
+                ModuleProcesser.api_parent_node[tid] = ModuleProcesser.module_stack[tid][-1]
            if self.scope:
                self.scope.end_module(full_name)
        else:
            if self.scope:
                self.scope.begin_module(full_name)
-            ModuleProcesser.api_parent_node = full_name
+            ModuleProcesser.api_parent_node[tid] = full_name
msprobe/pytorch/free_benchmark/result_handlers/base_handler.py

@@ -186,6 +186,8 @@ class FuzzHandler(ABC):
        ratio = self.ratio_calculate(
            origin_output, perturbed_output, norm_type=NormType.ENDLESS_NORM
        )
+        if threshold == 0:
+            raise ValueError("Threshold cannot be zero. Check `get_threshold` implementation.")
        if ratio == ThresholdConfig.SYMBOL_FLIPPING:
            is_consistent = False
        else:
msprobe/pytorch/hook_module/hook_module.py

@@ -22,20 +22,19 @@ import torch.nn as nn
 import torch.utils.hooks as full_hooks
 
 from msprobe.core.common.runtime import Runtime
-from msprobe.pytorch.common.utils import is_float8_tensor, register_forward_pre_hook, register_forward_hook
+from msprobe.core.common.utils import ThreadSafe
+from msprobe.pytorch.common.utils import register_forward_pre_hook, register_forward_hook
 
 
 class HOOKModule(nn.Module):
    module_count = defaultdict(int)
-    inner_stop_hook = {}
+    inner_stop_hook = defaultdict(bool)
 
    def __init__(self, hook_build_func) -> None:
        super(HOOKModule, self).__init__()
        self.has_overflow = False
-        self.current_thread = threading.current_thread().ident
-        if self.current_thread not in HOOKModule.inner_stop_hook:
-            HOOKModule.inner_stop_hook[self.current_thread] = False
-        self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
+        self.tid = threading.get_ident()
+        self.stop_hook = HOOKModule.inner_stop_hook.get(self.tid, False)
 
        if not self.stop_hook:
            self.forward_data_collected = False
@@ -43,6 +42,7 @@ class HOOKModule(nn.Module):
            if not Runtime.is_running:
                return
            prefix = self.prefix_api_name if hasattr(self, "prefix_api_name") else ""
+            ThreadSafe.acquire()
            if callable(hook_build_func):
                hook_set = hook_build_func(prefix)
                register_forward_pre_hook(self, hook_set.forward_pre_hook)
@@ -52,11 +52,11 @@
    def __call__(self, *args, **kwargs):
        changed = False
        if not self.stop_hook:
-            HOOKModule.inner_stop_hook[self.current_thread] = True
+            HOOKModule.inner_stop_hook[self.tid] = True
            changed = True
        result = self._call_func(*args, **kwargs)
        if changed:
-            HOOKModule.inner_stop_hook[self.current_thread] = False
+            HOOKModule.inner_stop_hook[self.tid] = False
        return result
 
    @staticmethod
@@ -104,7 +104,7 @@
        else:
            return result
 
-        if is_float8_tensor(var) or not (var.requires_grad and torch.is_grad_enabled()):
+        if not (var.requires_grad and torch.is_grad_enabled()):
            return result
 
        grad_fn = var.grad_fn
msprobe/pytorch/hook_module/pt_hook_manager.py (whitespace-only cleanups)

@@ -23,7 +23,7 @@ from msprobe.pytorch.common.utils import is_recomputation, torch_version_above_or_equal_2
 from msprobe.pytorch.hook_module.hook_module import HOOKModule
 
 
-class PytorchHookManager(BaseHookManager):
+class PytorchHookManager(BaseHookManager):
    @property
    def _is_recompute(self):
        return is_recomputation()
@@ -41,7 +41,7 @@ class PytorchHookManager(BaseHookManager):
        kwargs = kwargs_or_output if torch_version_above_or_equal_2 else {}
        output = output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output
        return kwargs, output
-
+
    def build_hook(self, hook_type, name):
        if hook_type == Const.API:
            full_forward_name = name + str(HOOKModule.get_module_count(name)) + Const.SEP + Const.FORWARD
@@ -51,10 +51,10 @@
            hookset = HookSet(
                forward_hook=self._build_forward_hook(hook_type, full_forward_name),
                forward_pre_hook=self._build_forward_pre_hook(hook_type, full_forward_name, name),
-                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
+                backward_hook=self._build_backward_hook(hook_type, full_backward_name)
            )
            return hookset
-
+
    def _need_exchange(self, module):
        return True
 
@@ -62,7 +62,7 @@
        params_dict = {}
        if self.config.task != Const.STRUCTURE:
            params_dict = {
-                key.split(Const.SEP)[-1]: value
-                for key, value in module.named_parameters(recurse=False)
-            }
+                key.split(Const.SEP)[-1]: value
+                for key, value in module.named_parameters(recurse=False)
+            }
        return params_dict
msprobe/pytorch/monitor/csv2tb.py

@@ -23,17 +23,17 @@ from tqdm import tqdm
 
 from msprobe.core.common.const import MonitorConst
 from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod
-from msprobe.core.common.utils import is_int
+from msprobe.core.common.utils import check_process_num
 from msprobe.core.common.decorator import recursion_depth_decorator
+from msprobe.core.monitor.utils import get_target_output_dir
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import get_target_output_dir
+
 
 all_data_type_list = [
    "actv", "actv_grad", "exp_avg", "exp_avg_sq",
    "grad_unreduced", "grad_reduced", "param_origin", "param_updated"
 ]
 CSV_FILE_SUFFIX = r"_\d+-\d+\.csv"
-MAX_PROCESS_NUM = 128
 
 
 def parse_step_line(line, ops):
@@ -119,13 +119,6 @@ def csv2tb_by_step_work(target_output_dirs, output_dirpath, data_type_list):
        write_step(output_dirpath, all_step_result, rank, data_type)
 
 
-def check_process_num(process_num):
-    if not is_int(process_num) or process_num <= 0:
-        raise ValueError(f"process_num({process_num}) is not a positive integer")
-    if process_num > MAX_PROCESS_NUM:
-        raise ValueError(f"The maximum supported process_num is {MAX_PROCESS_NUM}, current value: {process_num}.")
-
-
 def check_data_type_list(data_type_list):
    if data_type_list is None:
        logger.info(f"data_type_list is None, use default all_data_type_list: {all_data_type_list}")
msprobe/pytorch/monitor/features.py

@@ -45,13 +45,18 @@ def get_max(x: torch.tensor):
 
 @torch.no_grad()
 def get_zeros(x: torch.tensor, eps: float):
+    if x.numel() == 0:
+        return torch.tensor(float('nan'))
    return torch.sum(torch.abs(x) < eps) / x.numel()
 
 
 @torch.no_grad()
 def get_sign_matches(x: torch.tensor, y: torch.tensor):
+    if y.numel() == 0:
+        return torch.tensor(1.)
    xs = x.sign()
    ys = y.sign()
+
    try:
        same_direction_ratio = ((xs * ys).sum() / ys.numel() + 1) / 2
    except RuntimeError as e:
msprobe/pytorch/monitor/module_hook.py

@@ -31,8 +31,11 @@ from msprobe.core.common.decorator import recursion_depth_decorator
 from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter
 from msprobe.core.common.file_utils import write_df_to_csv
 from msprobe.core.common.utils import analyze_api_call_stack
+from msprobe.core.monitor.utils import validate_config, validate_ops, \
+    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor
+from msprobe.pytorch.common.utils import is_recomputation
+from msprobe.pytorch.monitor.utils import get_param_struct
 from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput
 from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \
    get_process_group
@@ -40,8 +43,6 @@ from msprobe.pytorch.monitor.features import get_sign_matches
 from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \
    TensorMetrics, squash_param_name
 from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory
-from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \
-    get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor
 from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer
 
 
@@ -592,7 +593,7 @@ class TrainerMon:
        context.param_adam_update = mv_result.update
        context.param_adam_ratio = mv_result.ratio
 
-        self.generate_wgrad_metrics(grad_dict)
+        _, _ = self.generate_wgrad_metrics(grad_dict)
        self.generate_mv_metrics(context)
        self.generate_param_metrics(context, MonitorConst.PRE_PARAM)
 
@@ -763,7 +764,7 @@
        def clone_if_tensor(args):
            if isinstance(args, tuple):
                return tuple([clone_if_tensor(arg) for arg in args])
-            elif isinstance(args, torch.Tensor) and not is_float8_tensor(args):
+            elif isinstance(args, torch.Tensor):
                return args.clone()
            else:
                return args
@@ -1170,8 +1171,6 @@
                grad = param.main_grad
            else:
                grad = param.grad
-            if is_float8_tensor(grad):
-                grad = grad.float()
            context_dict[key] = grad.clone()
 
            if param.micro_step == self.micro_batch_number:
msprobe/pytorch/monitor/module_metric.py

@@ -16,7 +16,6 @@ import re
 
 import torch
 
-from msprobe.pytorch.common.utils import is_float8_tensor
 from msprobe.pytorch.monitor.features import get_max, get_min, get_zeros, get_nans, get_norm, get_mean
 from msprobe.pytorch.monitor.utils import get_nan_tensor
 
@@ -181,8 +180,6 @@ def get_metrics(ops, tag2tensor, eps, out_dict=None):
            # Non-tensor in/output filled with nan.
            out_dict[tag].update({metric_name: get_nan_tensor() for metric_name in ops})
            continue
-        if is_float8_tensor(tensor):
-            tensor = tensor.float()
        for metric_name in ops:
            fun_metric = config_metric_registry.get(metric_name)
            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
msprobe/pytorch/monitor/optimizer_collect.py

@@ -17,7 +17,7 @@ from abc import abstractmethod
 import torch
 
 from msprobe.pytorch.common.log import logger
-from msprobe.pytorch.monitor.utils import MVResult
+from msprobe.core.monitor.utils import MVResult
 from msprobe.core.common.const import MonitorConst
 
 