mindstudio-probe 8.1.0__py3-none-any.whl → 8.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +3 -2
  2. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +46 -47
  3. msprobe/core/common/const.py +1 -0
  4. msprobe/core/common/file_utils.py +36 -18
  5. msprobe/core/common/utils.py +19 -8
  6. msprobe/core/compare/acc_compare.py +14 -5
  7. msprobe/core/compare/utils.py +7 -1
  8. msprobe/core/data_dump/data_collector.py +144 -90
  9. msprobe/core/data_dump/json_writer.py +31 -1
  10. msprobe/core/debugger/precision_debugger.py +19 -18
  11. msprobe/core/service.py +1 -0
  12. msprobe/core/single_save/single_comparator.py +25 -25
  13. msprobe/core/single_save/single_saver.py +5 -16
  14. msprobe/docs/01.installation.md +1 -0
  15. msprobe/docs/05.data_dump_PyTorch.md +3 -0
  16. msprobe/docs/06.data_dump_MindSpore.md +3 -0
  17. msprobe/docs/08.accuracy_checker_online_PyTorch.md +2 -2
  18. msprobe/docs/25.tool_function_introduction.md +19 -19
  19. msprobe/docs/33.generate_operator_MindSpore.md +10 -19
  20. msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -0
  21. msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
  22. msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +10 -1
  23. msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
  24. msprobe/mindspore/common/utils.py +1 -0
  25. msprobe/mindspore/debugger/precision_debugger.py +4 -4
  26. msprobe/mindspore/dump/cell_dump_process.py +13 -38
  27. msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +1 -26
  28. msprobe/mindspore/dump/hook_cell/api_register.py +3 -3
  29. msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +4 -4
  30. msprobe/mindspore/mindspore_service.py +3 -0
  31. msprobe/mindspore/monitor/features.py +10 -9
  32. msprobe/mindspore/monitor/optimizer_collect.py +4 -1
  33. msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
  34. msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +7 -7
  35. msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -0
  36. msprobe/pytorch/common/utils.py +1 -1
  37. msprobe/pytorch/debugger/precision_debugger.py +28 -25
  38. msprobe/pytorch/hook_module/api_register.py +3 -3
  39. msprobe/pytorch/monitor/optimizer_collect.py +4 -1
  40. msprobe/pytorch/pytorch_service.py +3 -0
  41. msprobe/visualization/compare/mode_adapter.py +9 -0
  42. msprobe/visualization/utils.py +3 -0
  43. msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +0 -9
  44. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
  45. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
  46. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
  47. {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
@@ -70,7 +70,7 @@ def split_json_file(input_file, num_splits, filter_api):
70
70
  split_forward_data = dict(items[start:end])
71
71
  temp_data = {
72
72
  **input_data,
73
- "data":{
73
+ "data": {
74
74
  **split_forward_data,
75
75
  **backward_data
76
76
  }
@@ -141,7 +141,7 @@ def run_parallel_ut(config):
141
141
 
142
142
  for api_info in config.api_files:
143
143
  cmd = create_cmd(api_info, next(device_id_cycle))
144
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
144
+ process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
145
145
  text=True, bufsize=1, shell=False)
146
146
  processes.append(process)
147
147
  threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
@@ -187,8 +187,8 @@ def run_parallel_ut(config):
187
187
 
188
188
 
189
189
  def prepare_config(args):
190
- api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
191
- ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
190
+ api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
191
+ ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
192
192
  api_info = api_info_file_checker.common_check()
193
193
  out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
194
194
  create_directory(out_path)
@@ -197,11 +197,11 @@ def prepare_config(args):
197
197
  split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
198
198
  config_path = args.config_path if args.config_path else None
199
199
  if config_path:
200
- config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
200
+ config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
201
201
  FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
202
202
  config_path = config_path_checker.common_check()
203
203
  result_csv_path = args.result_csv_path or os.path.join(
204
- out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
204
+ out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
205
205
  if not args.result_csv_path:
206
206
  details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
207
207
  comparator = Comparator(result_csv_path, details_csv_path, False)
@@ -220,7 +220,7 @@ def main():
220
220
  signal.signal(signal.SIGTERM, signal_handler)
221
221
  parser = argparse.ArgumentParser(description='Run UT in parallel')
222
222
  _run_ut_parser(parser)
223
- parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
223
+ parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
224
224
  help='Number of splits for parallel processing. Range: 1-64')
225
225
  args = parser.parse_args()
226
226
  config = prepare_config(args)
@@ -12,6 +12,7 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
+ import gc
15
16
  import os
16
17
  from datetime import datetime, timezone
17
18
 
@@ -117,6 +118,7 @@ def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
117
118
  with FileOpen(key_file, "rb") as f:
118
119
  key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
119
120
  del passphrase
121
+ gc.collect()
120
122
  with FileOpen(cert_file, "rb") as f:
121
123
  crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
122
124
  check_crt_valid(crt)
@@ -406,7 +406,7 @@ def load_api_data(api_data_bytes):
406
406
  """Load data from bytes stream"""
407
407
  try:
408
408
  buffer = io.BytesIO(api_data_bytes)
409
- buffer = torch.load(buffer, map_location="cpu")
409
+ buffer = torch.load(buffer, map_location="cpu", weights_only=False)
410
410
  except Exception as e:
411
411
  raise RuntimeError("load api_data from bytes failed") from e
412
412
  return buffer
@@ -53,19 +53,36 @@ class PrecisionDebugger(BasePrecisionDebugger):
53
53
  self.module_dumper = ModuleDumper(self.service)
54
54
  self.ori_customer_func = {}
55
55
  self.enable_dataloader = self.config.enable_dataloader
56
- self.param_warning()
57
-
58
- @property
59
- def instance(self):
60
- return self._instance
56
+ self._param_warning()
61
57
 
62
58
  @staticmethod
63
- def get_task_config(task, json_config):
59
+ def _get_task_config(task, json_config):
64
60
  return parse_task_config(task, json_config)
65
61
 
62
+ @staticmethod
63
+ def _iter_tracer(func):
64
+ def func_wrapper(*args, **kwargs):
65
+ debugger_instance = PrecisionDebugger._instance
66
+ if not debugger_instance:
67
+ raise MsprobeException(
68
+ MsprobeException.INTERFACE_USAGE_ERROR,
69
+ f"PrecisionDebugger must be instantiated before executing the dataloader iteration"
70
+ )
71
+
72
+ debugger_instance.enable_dataloader = False
73
+ if not debugger_instance.service.first_start:
74
+ debugger_instance.stop()
75
+ debugger_instance.step()
76
+ result = func(*args, **kwargs)
77
+ debugger_instance.start()
78
+ debugger_instance.enable_dataloader = True
79
+ return result
80
+
81
+ return func_wrapper
82
+
66
83
  @classmethod
67
84
  def start(cls, model=None, token_range=None):
68
- instance = cls.get_instance()
85
+ instance = cls._get_instance()
69
86
  if instance is None:
70
87
  return
71
88
 
@@ -79,7 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
79
96
 
80
97
  @classmethod
81
98
  def stop(cls):
82
- instance = cls.get_instance()
99
+ instance = cls._get_instance()
83
100
  if instance is None:
84
101
  return
85
102
  if instance.enable_dataloader:
@@ -89,7 +106,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
89
106
 
90
107
  @classmethod
91
108
  def step(cls):
92
- instance = cls.get_instance()
109
+ instance = cls._get_instance()
93
110
  if instance is None:
94
111
  return
95
112
  cls._instance.service.step()
@@ -115,7 +132,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
115
132
  return
116
133
  instance.service.save(variable, name, save_backward)
117
134
 
118
- def param_warning(self):
135
+ def _param_warning(self):
119
136
  if self.model is not None:
120
137
  logger.warning_on_rank_0(
121
138
  "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
@@ -123,7 +140,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
123
140
  )
124
141
  if self.enable_dataloader:
125
142
  logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
126
- dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
143
+ dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)
127
144
 
128
145
 
129
146
  def module_dump(module, dump_name):
@@ -155,17 +172,3 @@ def module_dump_end():
155
172
  f"PrecisionDebugger must be instantiated before using module_dump_end interface"
156
173
  )
157
174
  instance.module_dumper.stop_module_dump()
158
-
159
-
160
- def iter_tracer(func):
161
- def func_wrapper(*args, **kwargs):
162
- debugger_instance = PrecisionDebugger.instance
163
- debugger_instance.enable_dataloader = False
164
- if not debugger_instance.service.first_start:
165
- debugger_instance.stop()
166
- debugger_instance.step()
167
- result = func(*args, **kwargs)
168
- debugger_instance.start()
169
- debugger_instance.enable_dataloader = True
170
- return result
171
- return func_wrapper
@@ -89,12 +89,12 @@ def dist_module_forward(module, *args, **kwargs):
89
89
  try:
90
90
  bound = inspect.signature(module.api_func).bind(*args, **kwargs)
91
91
  bound.apply_defaults()
92
- use_asyn_op_flag = bound.arguments.get("asyn_op", False)
92
+ use_async_op_flag = bound.arguments.get("async_op", False)
93
93
  except Exception as e:
94
- use_asyn_op_flag = False
94
+ use_async_op_flag = False
95
95
  logger.warning(f"fail to get dist api's func signature because {e}, no wait")
96
96
 
97
- if use_asyn_op_flag or module.api_name in ["isend", "irecv"]:
97
+ if use_async_op_flag or module.api_name in ["isend", "irecv"]:
98
98
  if handle and hasattr(handle, 'wait'):
99
99
  handle.wait()
100
100
  if module.api_name == "batch_isend_irecv":
@@ -109,6 +109,9 @@ class OptimizerMon(object):
109
109
  else:
110
110
  logger.warning(f"step of {name} is None, maybe something wrong happened.")
111
111
  continue
112
+ if exp_avg is None or exp_avg_sq is None:
113
+ logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
114
+ continue
112
115
  exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
113
116
  exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
114
117
  update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
@@ -296,7 +299,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
296
299
  self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
297
300
  self.param2group = self.get_group_index()
298
301
 
299
- def param_not_in_partition(self, param, group_index):
302
+ def param_not_in_partition(self, lp_param, group_idx):
300
303
  """Each param partioned across all zero ranks"""
301
304
  return False
302
305
 
@@ -37,6 +37,9 @@ class PytorchService(BaseService):
37
37
  @staticmethod
38
38
  def _get_current_rank():
39
39
  return get_rank_if_initialized()
40
+
41
+ def reset_status(self):
42
+ self._reset_status()
40
43
 
41
44
  def _init_specific_components(self):
42
45
  self.logger = logger
@@ -161,6 +161,7 @@ class ModeAdapter:
161
161
  else change_percentage
162
162
  precision_index = GraphConst.MAX_INDEX_KEY \
163
163
  if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
164
+ precision_index = self._ignore_precision_index(node.id, precision_index)
164
165
  return precision_index, other_dict
165
166
 
166
167
  def prepare_real_data(self, node):
@@ -197,3 +198,11 @@ class ModeAdapter:
197
198
  CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
198
199
  CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
199
200
  return json.dumps(tips)
201
+
202
+ def _ignore_precision_index(self, node_id, precision_index):
203
+ node_id_split = node_id.split(Const.SEP)
204
+ if len(node_id_split) < 2:
205
+ return precision_index
206
+ if node_id.split(Const.SEP)[1] in GraphConst.IGNORE_PRECISION_INDEX:
207
+ return GraphConst.MAX_INDEX_KEY if self.compare_mode == GraphConst.MD5_COMPARE else GraphConst.MIN_INDEX_KEY
208
+ return precision_index
@@ -184,6 +184,9 @@ class GraphConst:
184
184
  OP = 'op'
185
185
  PEER = 'peer'
186
186
  GROUP_ID = 'group_id'
187
+
188
+ IGNORE_PRECISION_INDEX = {'empty', 'empty_like', 'empty_with_format', 'new_empty_strided', 'new_empty',
189
+ 'empty_strided'}
187
190
 
188
191
 
189
192
  def is_serializable(obj):
@@ -1,9 +0,0 @@
1
- {
2
- "dump_json_path": "./dump.json",
3
- "api_name": "Mint.split.1",
4
- "extract_api_path": "Mint.split.1.json",
5
- "propagation": "backward",
6
- "data_mode": "random_data",
7
- "random_seed": 1234,
8
- "iter_times": 1
9
- }