mindstudio-probe 8.1.0__py3-none-any.whl → 8.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/METADATA +3 -2
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/RECORD +46 -47
- msprobe/core/common/const.py +1 -0
- msprobe/core/common/file_utils.py +36 -18
- msprobe/core/common/utils.py +19 -8
- msprobe/core/compare/acc_compare.py +14 -5
- msprobe/core/compare/utils.py +7 -1
- msprobe/core/data_dump/data_collector.py +144 -90
- msprobe/core/data_dump/json_writer.py +31 -1
- msprobe/core/debugger/precision_debugger.py +19 -18
- msprobe/core/service.py +1 -0
- msprobe/core/single_save/single_comparator.py +25 -25
- msprobe/core/single_save/single_saver.py +5 -16
- msprobe/docs/01.installation.md +1 -0
- msprobe/docs/05.data_dump_PyTorch.md +3 -0
- msprobe/docs/06.data_dump_MindSpore.md +3 -0
- msprobe/docs/08.accuracy_checker_online_PyTorch.md +2 -2
- msprobe/docs/25.tool_function_introduction.md +19 -19
- msprobe/docs/33.generate_operator_MindSpore.md +10 -19
- msprobe/mindspore/api_accuracy_checker/api_accuracy_checker.py +1 -0
- msprobe/mindspore/api_accuracy_checker/compute_element.py +0 -1
- msprobe/mindspore/api_accuracy_checker/generate_op_script/op_generator.py +10 -1
- msprobe/mindspore/api_accuracy_checker/torch_mindtorch_importer.py +2 -1
- msprobe/mindspore/common/utils.py +1 -0
- msprobe/mindspore/debugger/precision_debugger.py +4 -4
- msprobe/mindspore/dump/cell_dump_process.py +13 -38
- msprobe/mindspore/dump/cell_dump_with_insert_gradient.py +1 -26
- msprobe/mindspore/dump/hook_cell/api_register.py +3 -3
- msprobe/mindspore/dym_loader/hook_dynamic_loader.cpp +4 -4
- msprobe/mindspore/mindspore_service.py +3 -0
- msprobe/mindspore/monitor/features.py +10 -9
- msprobe/mindspore/monitor/optimizer_collect.py +4 -1
- msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +20 -20
- msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +7 -7
- msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py +2 -0
- msprobe/pytorch/common/utils.py +1 -1
- msprobe/pytorch/debugger/precision_debugger.py +28 -25
- msprobe/pytorch/hook_module/api_register.py +3 -3
- msprobe/pytorch/monitor/optimizer_collect.py +4 -1
- msprobe/pytorch/pytorch_service.py +3 -0
- msprobe/visualization/compare/mode_adapter.py +9 -0
- msprobe/visualization/utils.py +3 -0
- msprobe/mindspore/api_accuracy_checker/generate_op_script/config_op.json +0 -9
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE +0 -0
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL +0 -0
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt +0 -0
- {mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt +0 -0
msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
CHANGED

@@ -70,7 +70,7 @@ def split_json_file(input_file, num_splits, filter_api):
         split_forward_data = dict(items[start:end])
         temp_data = {
             **input_data,
-            "data":{
+            "data": {
                 **split_forward_data,
                 **backward_data
             }
@@ -141,7 +141,7 @@ def run_parallel_ut(config):
 
     for api_info in config.api_files:
         cmd = create_cmd(api_info, next(device_id_cycle))
-        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
+        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                                    text=True, bufsize=1, shell=False)
         processes.append(process)
         threading.Thread(target=read_process_output, args=(process,), daemon=True).start()
@@ -187,8 +187,8 @@ def run_parallel_ut(config):
 
 
 def prepare_config(args):
-    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
-
+    api_info_file_checker = FileChecker(file_path=args.api_info_file, path_type=FileCheckConst.FILE,
+                                        ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX)
     api_info = api_info_file_checker.common_check()
     out_path = args.out_path if args.out_path else Const.DEFAULT_PATH
     create_directory(out_path)
@@ -197,11 +197,11 @@ def prepare_config(args):
     split_files, total_items = split_json_file(api_info, args.num_splits, args.filter_api)
     config_path = args.config_path if args.config_path else None
     if config_path:
-        config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
+        config_path_checker = FileChecker(config_path, FileCheckConst.FILE,
                                           FileCheckConst.READ_ABLE, FileCheckConst.JSON_SUFFIX)
         config_path = config_path_checker.common_check()
     result_csv_path = args.result_csv_path or os.path.join(
-
+        out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv")
     if not args.result_csv_path:
         details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv")
     comparator = Comparator(result_csv_path, details_csv_path, False)
@@ -220,7 +220,7 @@ def main():
     signal.signal(signal.SIGTERM, signal_handler)
     parser = argparse.ArgumentParser(description='Run UT in parallel')
     _run_ut_parser(parser)
-    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
+    parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
                         help='Number of splits for parallel processing. Range: 1-64')
     args = parser.parse_args()
     config = prepare_config(args)
msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/utils.py
CHANGED

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import os
 from datetime import datetime, timezone
@@ -117,6 +118,7 @@ def load_ssl_pem(key_file, cert_file, ca_file, crl_file):
     with FileOpen(key_file, "rb") as f:
         key = crypto.load_privatekey(crypto.FILETYPE_PEM, f.read(), passphrase.encode())
     del passphrase
+    gc.collect()
     with FileOpen(cert_file, "rb") as f:
         crt = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
     check_crt_valid(crt)
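A small aside on the gc.collect() added after del passphrase: deleting the name only drops one reference, and the explicit collection is a best-effort way to reclaim the sensitive value sooner. A minimal standalone sketch of the pattern (the passphrase value here is made up):

```python
# Illustrative only: drop a sensitive value and prompt collection,
# mirroring the del/gc.collect() pairing added to load_ssl_pem above.
import gc

passphrase = "s3cret"              # stand-in value, not the real passphrase source
key_material = passphrase.encode()

del passphrase                     # remove the last strong reference
gc.collect()                       # also sweep any cyclic garbage still holding it
```

CPython frees most objects by reference counting as soon as del runs; the explicit collect() mainly helps when the value is still reachable through a cycle, so this is hygiene rather than a guarantee.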
msprobe/pytorch/common/utils.py
CHANGED
@@ -406,7 +406,7 @@ def load_api_data(api_data_bytes):
     """Load data from bytes stream"""
     try:
         buffer = io.BytesIO(api_data_bytes)
-        buffer = torch.load(buffer, map_location="cpu")
+        buffer = torch.load(buffer, map_location="cpu", weights_only=False)
     except Exception as e:
         raise RuntimeError("load api_data from bytes failed") from e
     return buffer
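For context on the weights_only change: recent PyTorch releases move torch.load toward weights_only=True by default, which rejects arbitrary pickled objects such as the argument dictionaries this transport layer ships, so the call now opts out explicitly. A minimal round-trip sketch, assuming a PyTorch version that accepts the weights_only keyword; the payload contents are invented:

```python
# Sketch: serialize a dict of tensors/kwargs to bytes and load it back,
# as load_api_data does. weights_only=False restores full-pickle loading.
import io
import torch

payload = {"args": (torch.randn(2, 3),), "kwargs": {"alpha": 1.0}}

buf = io.BytesIO()
torch.save(payload, buf)
api_data_bytes = buf.getvalue()

restored = torch.load(io.BytesIO(api_data_bytes), map_location="cpu",
                      weights_only=False)
print(restored["kwargs"])  # {'alpha': 1.0}
```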
msprobe/pytorch/debugger/precision_debugger.py
CHANGED

@@ -53,19 +53,36 @@ class PrecisionDebugger(BasePrecisionDebugger):
         self.module_dumper = ModuleDumper(self.service)
         self.ori_customer_func = {}
         self.enable_dataloader = self.config.enable_dataloader
-        self.
-
-    @property
-    def instance(self):
-        return self._instance
+        self._param_warning()
 
     @staticmethod
-    def
+    def _get_task_config(task, json_config):
         return parse_task_config(task, json_config)
 
+    @staticmethod
+    def _iter_tracer(func):
+        def func_wrapper(*args, **kwargs):
+            debugger_instance = PrecisionDebugger._instance
+            if not debugger_instance:
+                raise MsprobeException(
+                    MsprobeException.INTERFACE_USAGE_ERROR,
+                    f"PrecisionDebugger must be instantiated before executing the dataloader iteration"
+                )
+
+            debugger_instance.enable_dataloader = False
+            if not debugger_instance.service.first_start:
+                debugger_instance.stop()
+                debugger_instance.step()
+            result = func(*args, **kwargs)
+            debugger_instance.start()
+            debugger_instance.enable_dataloader = True
+            return result
+
+        return func_wrapper
+
     @classmethod
     def start(cls, model=None, token_range=None):
-        instance = cls.
+        instance = cls._get_instance()
         if instance is None:
             return
@@ -79,7 +96,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
 
     @classmethod
     def stop(cls):
-        instance = cls.
+        instance = cls._get_instance()
         if instance is None:
             return
         if instance.enable_dataloader:
@@ -89,7 +106,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
 
     @classmethod
     def step(cls):
-        instance = cls.
+        instance = cls._get_instance()
         if instance is None:
             return
         cls._instance.service.step()
@@ -115,7 +132,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
             return
         instance.service.save(variable, name, save_backward)
 
-    def
+    def _param_warning(self):
         if self.model is not None:
             logger.warning_on_rank_0(
                 "The 'model' parameter in the PrecisionDebugger will be deprecated in the future."
@@ -123,7 +140,7 @@ class PrecisionDebugger(BasePrecisionDebugger):
             )
         if self.enable_dataloader:
             logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.")
-            dataloader._BaseDataLoaderIter.__next__ =
+            dataloader._BaseDataLoaderIter.__next__ = self._iter_tracer(dataloader._BaseDataLoaderIter.__next__)
 
 
 def module_dump(module, dump_name):
@@ -155,17 +172,3 @@ def module_dump_end():
             f"PrecisionDebugger must be instantiated before using module_dump_end interface"
         )
     instance.module_dumper.stop_module_dump()
-
-
-def iter_tracer(func):
-    def func_wrapper(*args, **kwargs):
-        debugger_instance = PrecisionDebugger.instance
-        debugger_instance.enable_dataloader = False
-        if not debugger_instance.service.first_start:
-            debugger_instance.stop()
-            debugger_instance.step()
-        result = func(*args, **kwargs)
-        debugger_instance.start()
-        debugger_instance.enable_dataloader = True
-        return result
-    return func_wrapper
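The module-level iter_tracer removed above has moved into the class as the _iter_tracer static method shown earlier: wrapping _BaseDataLoaderIter.__next__ makes each batch fetch close the previous dump step, advance the step counter, and restart dumping. A standalone sketch of that wrapping pattern, with invented stand-ins (DemoDebugger, DemoIter) rather than msprobe classes:

```python
# Standalone illustration of the __next__-wrapping pattern (DemoDebugger and
# DemoIter are invented for this sketch; they are not msprobe classes).
class DemoDebugger:
    first_start = False
    def start(self): print("start dump")
    def stop(self): print("stop dump")
    def step(self): print("advance step")

def iter_tracer(func, debugger):
    def func_wrapper(*args, **kwargs):
        if not debugger.first_start:
            debugger.stop()    # close out the previous step...
            debugger.step()    # ...and advance the step counter
        result = func(*args, **kwargs)   # fetch the next batch
        debugger.start()                 # resume dumping for the new step
        return result
    return func_wrapper

class DemoIter:
    def __init__(self):
        self.n = 0
    def __next__(self):
        self.n += 1
        return self.n

debugger = DemoDebugger()
DemoIter.__next__ = iter_tracer(DemoIter.__next__, debugger)
it = DemoIter()
print(next(it), next(it))  # dump hooks fire around each batch fetch
```

Keeping the wrapper on the class also lets it check PrecisionDebugger._instance and raise an explicit usage error when the debugger was never instantiated, as the new code above does.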
@@ -89,12 +89,12 @@ def dist_module_forward(module, *args, **kwargs):
     try:
         bound = inspect.signature(module.api_func).bind(*args, **kwargs)
         bound.apply_defaults()
-
+        use_async_op_flag = bound.arguments.get("async_op", False)
     except Exception as e:
-
+        use_async_op_flag = False
         logger.warning(f"fail to get dist api's func signature because {e}, no wait")
 
-    if
+    if use_async_op_flag or module.api_name in ["isend", "irecv"]:
         if handle and hasattr(handle, 'wait'):
             handle.wait()
         if module.api_name == "batch_isend_irecv":
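The replaced lines recover the effective async_op argument of the wrapped distributed call by binding the observed arguments against the API's signature, so keyword defaults become visible as well. A self-contained sketch of that binding trick; dist_op is a stand-in, not a real torch.distributed API:

```python
# Sketch of reading a (possibly defaulted) keyword via signature binding.
import inspect

def dist_op(tensor, dst=0, async_op=False):   # stand-in signature
    return "handle" if async_op else None

def get_async_op_flag(func, *args, **kwargs):
    try:
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()                 # make defaults visible in .arguments
        return bound.arguments.get("async_op", False)
    except (TypeError, ValueError):
        return False                           # fall back to "no wait" on failure

print(get_async_op_flag(dist_op, "tensor"))                 # False
print(get_async_op_flag(dist_op, "tensor", async_op=True))  # True
```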
msprobe/pytorch/monitor/optimizer_collect.py
CHANGED

@@ -109,6 +109,9 @@ class OptimizerMon(object):
             else:
                 logger.warning(f"step of {name} is None, maybe something wrong happened.")
                 continue
+            if exp_avg is None or exp_avg_sq is None:
+                logger.warning(f"exp_avg or exp_avg_sq of {name} is None, skip calculation.")
+                continue
             exp_avg_hat = exp_avg / (1 - self.torch_opt.defaults['betas'][0] ** step)
             exp_avg_sq_hat = exp_avg_sq / (1 - self.torch_opt.defaults['betas'][1] ** step)
             update_dict[name] = exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + self.torch_opt.defaults['eps'])
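The guarded block reconstructs Adam's bias-corrected update from the optimizer state, which is why a missing exp_avg or exp_avg_sq now skips the parameter instead of failing on the arithmetic. A standalone sketch of the same calculation; the hyperparameter values are assumptions, the real ones come from torch_opt.defaults:

```python
# Sketch of the bias-corrected Adam update reconstructed by the monitor.
import torch

def adam_update(exp_avg, exp_avg_sq, step, beta1=0.9, beta2=0.999, eps=1e-8):
    # Mirrors the new guard: a parameter with missing state is skipped.
    if exp_avg is None or exp_avg_sq is None:
        return None
    exp_avg_hat = exp_avg / (1 - beta1 ** step)        # bias-corrected m_t
    exp_avg_sq_hat = exp_avg_sq / (1 - beta2 ** step)  # bias-corrected v_t
    return exp_avg_hat / (torch.sqrt(exp_avg_sq_hat) + eps)

print(adam_update(torch.tensor([0.05, -0.02]), torch.tensor([0.004, 0.001]), step=10))
print(adam_update(None, None, step=10))  # skipped instead of raising a TypeError
```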
@@ -296,7 +299,7 @@ class DeepSpeedZeroOptimizerStage3Mon(DeepSpeedZeroOptimizerMon):
         self.fp32_flat_groups = torch_opt.fp32_partitioned_groups_flat
         self.param2group = self.get_group_index()
 
-    def param_not_in_partition(self,
+    def param_not_in_partition(self, lp_param, group_idx):
         """Each param partioned across all zero ranks"""
         return False
msprobe/visualization/compare/mode_adapter.py
CHANGED

@@ -161,6 +161,7 @@ class ModeAdapter:
             else change_percentage
         precision_index = GraphConst.MAX_INDEX_KEY \
             if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
+        precision_index = self._ignore_precision_index(node.id, precision_index)
         return precision_index, other_dict
 
     def prepare_real_data(self, node):
@@ -197,3 +198,11 @@ class ModeAdapter:
                 CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
                 CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
         return json.dumps(tips)
+
+    def _ignore_precision_index(self, node_id, precision_index):
+        node_id_split = node_id.split(Const.SEP)
+        if len(node_id_split) < 2:
+            return precision_index
+        if node_id.split(Const.SEP)[1] in GraphConst.IGNORE_PRECISION_INDEX:
+            return GraphConst.MAX_INDEX_KEY if self.compare_mode == GraphConst.MD5_COMPARE else GraphConst.MIN_INDEX_KEY
+        return precision_index
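The new _ignore_precision_index helper keys off the second dot-separated segment of the node id to pin the precision index for certain node types. A tiny sketch of that parsing; the separator, the ignore set, and the index constants below are invented stand-ins for Const.SEP and the GraphConst values:

```python
# Invented constants for illustration; msprobe's real values live in Const/GraphConst.
SEP = "."
IGNORE_PRECISION_INDEX = {"empty", "empty_like", "empty_with_format"}
MAX_INDEX_KEY, MIN_INDEX_KEY = 1, 0

def ignore_precision_index(node_id, precision_index, md5_mode=False):
    parts = node_id.split(SEP)
    if len(parts) < 2:
        return precision_index
    if parts[1] in IGNORE_PRECISION_INDEX:
        return MAX_INDEX_KEY if md5_mode else MIN_INDEX_KEY
    return precision_index

print(ignore_precision_index("Tensor.empty.0.forward", 0.7))  # 0 (pinned)
print(ignore_precision_index("Tensor.add.0.forward", 0.7))    # 0.7 (unchanged)
```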
msprobe/visualization/utils.py
CHANGED

{mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/LICENSE: File without changes
{mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/WHEEL: File without changes
{mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/entry_points.txt: File without changes
{mindstudio_probe-8.1.0.dist-info → mindstudio_probe-8.1.1.dist-info}/top_level.txt: File without changes