mindspore 2.3.0rc1__cp38-cp38-manylinux1_x86_64.whl → 2.3.0rc2__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +20 -0
- mindspore/_extends/parse/parser.py +1 -1
- mindspore/_extends/parse/standard_method.py +6 -5
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -5
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -2
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +1 -0
- mindspore/common/api.py +56 -4
- mindspore/common/dtype.py +5 -3
- mindspore/common/dump.py +2 -2
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +17 -6
- mindspore/common/parameter.py +7 -2
- mindspore/common/recompute.py +247 -0
- mindspore/common/sparse_tensor.py +2 -2
- mindspore/common/symbol.py +1 -1
- mindspore/common/tensor.py +74 -36
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/management.py +30 -30
- mindspore/context.py +28 -15
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +51 -51
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +3 -3
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +3 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +3 -3
- mindspore/dataset/engine/datasets_vision.py +68 -68
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +26 -26
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/transforms.py +92 -92
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/experimental/optim/adadelta.py +2 -2
- mindspore/experimental/optim/adagrad.py +2 -2
- mindspore/experimental/optim/adam.py +2 -2
- mindspore/experimental/optim/adamax.py +2 -2
- mindspore/experimental/optim/adamw.py +2 -2
- mindspore/experimental/optim/asgd.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +24 -20
- mindspore/experimental/optim/nadam.py +2 -2
- mindspore/experimental/optim/optimizer.py +1 -1
- mindspore/experimental/optim/radam.py +2 -2
- mindspore/experimental/optim/rmsprop.py +2 -2
- mindspore/experimental/optim/rprop.py +2 -2
- mindspore/experimental/optim/sgd.py +2 -2
- mindspore/hal/stream.py +2 -0
- mindspore/include/mindapi/base/types.h +5 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +2 -2
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +4 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +51 -47
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/nn/extend/layer/__init__.py +27 -0
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/basic.py +109 -1
- mindspore/nn/layer/container.py +2 -2
- mindspore/nn/layer/conv.py +6 -6
- mindspore/nn/layer/embedding.py +1 -1
- mindspore/nn/layer/normalization.py +21 -43
- mindspore/nn/layer/padding.py +4 -0
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +7 -7
- mindspore/nn/optim/adamax.py +2 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +1 -1
- mindspore/nn/optim/lamb.py +3 -3
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +2 -2
- mindspore/nn/optim/momentum.py +2 -2
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +2 -2
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/sgd.py +2 -2
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +9 -9
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
- mindspore/ops/_vmap/vmap_math_ops.py +27 -8
- mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
- mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
- mindspore/ops/auto_generate/gen_extend_func.py +274 -0
- mindspore/ops/auto_generate/gen_ops_def.py +889 -22
- mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
- mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
- mindspore/ops/extend/__init__.py +9 -1
- mindspore/ops/extend/array_func.py +134 -27
- mindspore/ops/extend/math_func.py +3 -3
- mindspore/ops/extend/nn_func.py +363 -2
- mindspore/ops/function/__init__.py +19 -2
- mindspore/ops/function/array_func.py +463 -439
- mindspore/ops/function/clip_func.py +7 -18
- mindspore/ops/function/grad/grad_func.py +5 -5
- mindspore/ops/function/linalg_func.py +4 -4
- mindspore/ops/function/math_func.py +260 -243
- mindspore/ops/function/nn_func.py +825 -62
- mindspore/ops/function/random_func.py +73 -4
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +1 -1
- mindspore/ops/functional.py +2 -2
- mindspore/ops/op_info_register.py +1 -31
- mindspore/ops/operations/__init__.py +2 -3
- mindspore/ops/operations/_grad_ops.py +2 -107
- mindspore/ops/operations/_inner_ops.py +5 -5
- mindspore/ops/operations/_sequence_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +11 -233
- mindspore/ops/operations/comm_ops.py +32 -32
- mindspore/ops/operations/custom_ops.py +7 -89
- mindspore/ops/operations/manually_defined/ops_def.py +329 -4
- mindspore/ops/operations/math_ops.py +13 -163
- mindspore/ops/operations/nn_ops.py +9 -316
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +2 -2
- mindspore/ops_generate/arg_dtype_cast.py +12 -3
- mindspore/ops_generate/arg_handler.py +24 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
- mindspore/ops_generate/gen_pyboost_func.py +13 -6
- mindspore/ops_generate/pyboost_utils.py +2 -17
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +106 -1
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_utils.py +16 -0
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/checkpoint_transform.py +249 -77
- mindspore/parallel/cluster/process_entity/_api.py +1 -1
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +1 -1
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
- mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
- mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
- mindspore/profiler/parser/ascend_op_generator.py +26 -9
- mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
- mindspore/profiler/parser/profiler_info.py +11 -1
- mindspore/profiler/profiling.py +13 -5
- mindspore/rewrite/api/node.py +12 -12
- mindspore/rewrite/api/symbol_tree.py +11 -11
- mindspore/run_check/_check_version.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +2 -2
- mindspore/train/amp.py +4 -4
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +2 -2
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +2 -2
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/dataset_helper.py +8 -3
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/mind_ir_pb2.py +22 -17
- mindspore/train/model.py +15 -15
- mindspore/train/serialization.py +18 -18
- mindspore/train/summary/summary_record.py +7 -7
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +223 -209
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
|
@@ -67,7 +67,9 @@ class AscendOPGenerator:
|
|
|
67
67
|
"""
|
|
68
68
|
Analyse op summary op statistic generate op data.
|
|
69
69
|
"""
|
|
70
|
-
|
|
70
|
+
if isinstance(self.op_summary, np.ndarray) and self.op_summary.shape[0] == 0 or \
|
|
71
|
+
not isinstance(self.op_summary, np.ndarray) and not self.op_summary:
|
|
72
|
+
return
|
|
71
73
|
self._combine_op_and_kernel(self.op_summary, self.launch_ops)
|
|
72
74
|
# aicore intermediation detail
|
|
73
75
|
self.op_detail = self._parse_op_detail(self.op_summary)
|
|
@@ -97,7 +99,7 @@ class AscendOPGenerator:
|
|
|
97
99
|
output_timeline_data_path : output_timeline_data.txt path
|
|
98
100
|
"""
|
|
99
101
|
# aicore intermediation detail
|
|
100
|
-
if self.op_detail.shape[0] != 0:
|
|
102
|
+
if isinstance(self.op_detail, np.ndarray) and self.op_detail.shape[0] != 0:
|
|
101
103
|
try:
|
|
102
104
|
with os.fdopen(os.open(aicore_intermediate_detail_path,
|
|
103
105
|
os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
|
|
@@ -112,7 +114,7 @@ class AscendOPGenerator:
|
|
|
112
114
|
os.chmod(aicore_intermediate_detail_path, stat.S_IREAD | stat.S_IWRITE)
|
|
113
115
|
|
|
114
116
|
# aicore intermediation type
|
|
115
|
-
if self.op_type.shape[0] != 0:
|
|
117
|
+
if isinstance(self.op_type, np.ndarray) and self.op_type.shape[0] != 0:
|
|
116
118
|
try:
|
|
117
119
|
with os.fdopen(os.open(aicore_intermediate_type_path,
|
|
118
120
|
os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
|
|
@@ -127,7 +129,7 @@ class AscendOPGenerator:
|
|
|
127
129
|
os.chmod(aicore_intermediate_type_path, stat.S_IREAD | stat.S_IWRITE)
|
|
128
130
|
|
|
129
131
|
# aicpu_intermediation
|
|
130
|
-
if self.aicpu_detail.shape[0] != 0:
|
|
132
|
+
if isinstance(self.aicpu_detail, np.ndarray) and self.aicpu_detail.shape[0] != 0:
|
|
131
133
|
try:
|
|
132
134
|
with os.fdopen(os.open(aicpu_intermediate_detail_path,
|
|
133
135
|
os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
|
|
@@ -142,7 +144,7 @@ class AscendOPGenerator:
|
|
|
142
144
|
os.chmod(aicpu_intermediate_detail_path, stat.S_IREAD | stat.S_IWRITE)
|
|
143
145
|
|
|
144
146
|
# framwork_raw
|
|
145
|
-
if self.framework_raw.shape[0] != 0:
|
|
147
|
+
if isinstance(self.framework_raw, np.ndarray) and self.framework_raw.shape[0] != 0:
|
|
146
148
|
try:
|
|
147
149
|
with os.fdopen(os.open(framework_raw_path,
|
|
148
150
|
os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
|
|
@@ -157,7 +159,8 @@ class AscendOPGenerator:
|
|
|
157
159
|
os.chmod(framework_raw_path, stat.S_IREAD | stat.S_IWRITE)
|
|
158
160
|
|
|
159
161
|
# output_timeline_data
|
|
160
|
-
if self.output_timeline_data.
|
|
162
|
+
if isinstance(self.output_timeline_data, np.ndarray) and self.output_timeline_data.size and \
|
|
163
|
+
self.output_timeline_data.shape[0] != 0 and output_timeline_data_path:
|
|
161
164
|
try:
|
|
162
165
|
with os.fdopen(os.open(output_timeline_data_path,
|
|
163
166
|
os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IWUSR | stat.S_IRUSR),
|
|
@@ -173,6 +176,9 @@ class AscendOPGenerator:
|
|
|
173
176
|
|
|
174
177
|
def _combine_op_and_kernel(self, op_summary, launch_ops):
|
|
175
178
|
"""update op name, kernel name etc."""
|
|
179
|
+
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or not isinstance(op_summary, np.ndarray) \
|
|
180
|
+
and not op_summary:
|
|
181
|
+
return
|
|
176
182
|
self._full_kernel_name = op_summary['Op Name'].copy()
|
|
177
183
|
self._op_name = op_summary['Op Name'].copy()
|
|
178
184
|
self._kernel_name = np.array(
|
|
@@ -199,6 +205,9 @@ class AscendOPGenerator:
|
|
|
199
205
|
Args:
|
|
200
206
|
op_summary(DataFrame): op summary data.
|
|
201
207
|
"""
|
|
208
|
+
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
|
|
209
|
+
not isinstance(op_summary, np.ndarray) and not op_summary:
|
|
210
|
+
return None
|
|
202
211
|
if self.aclnn_status:
|
|
203
212
|
op_detail = np.empty((len(op_summary),), dtype=self.op_detail_dt)
|
|
204
213
|
op_detail['task_type'] = op_summary['Task Type']
|
|
@@ -226,7 +235,9 @@ class AscendOPGenerator:
|
|
|
226
235
|
Args:
|
|
227
236
|
op_statistic(DataFrame): op statistic data.
|
|
228
237
|
"""
|
|
229
|
-
|
|
238
|
+
if isinstance(op_statistic, np.ndarray) and op_statistic.shape[0] == 0 or \
|
|
239
|
+
not isinstance(op_statistic, np.ndarray) and not op_statistic:
|
|
240
|
+
return None
|
|
230
241
|
groups, _, inverse, _ = np.unique(op_statistic['Op Type'], return_index=True, return_inverse=True,
|
|
231
242
|
return_counts=True)
|
|
232
243
|
|
|
@@ -246,7 +257,9 @@ class AscendOPGenerator:
|
|
|
246
257
|
Args:
|
|
247
258
|
op_summary(DataFrame): op summary data.
|
|
248
259
|
"""
|
|
249
|
-
|
|
260
|
+
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
|
|
261
|
+
not isinstance(op_summary, np.ndarray) and not op_summary:
|
|
262
|
+
return None
|
|
250
263
|
op_summary = op_summary[op_summary['Task Type'] == 'AI_CPU']
|
|
251
264
|
|
|
252
265
|
aicpu_detail = np.empty((len(op_summary),), dtype=self.aicpu_detail_dt)
|
|
@@ -271,6 +284,8 @@ class AscendOPGenerator:
|
|
|
271
284
|
|
|
272
285
|
def op_info_analyse(row):
|
|
273
286
|
"""generate op info data"""
|
|
287
|
+
if not row['Input Shapes']:
|
|
288
|
+
return ""
|
|
274
289
|
input_shapes = row['Input Shapes'].replace('"', '').split(';')
|
|
275
290
|
input_data_types = row['Input Data Types'].replace('_', '').split(';')
|
|
276
291
|
input_formats = row['Input Formats'].replace('_', '').split(';')
|
|
@@ -295,7 +310,9 @@ class AscendOPGenerator:
|
|
|
295
310
|
'shape': output_shapes[i]
|
|
296
311
|
}
|
|
297
312
|
return json.dumps(op_info)
|
|
298
|
-
|
|
313
|
+
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
|
|
314
|
+
not isinstance(op_summary, np.ndarray) and not op_summary:
|
|
315
|
+
return None
|
|
299
316
|
if self.dynamic_status or self.aclnn_status:
|
|
300
317
|
index = list(range(op_summary.shape[0]))
|
|
301
318
|
else:
|
|
@@ -37,12 +37,13 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
|
|
|
37
37
|
scope_index = 1
|
|
38
38
|
cpu_index = 2
|
|
39
39
|
|
|
40
|
-
def __init__(self, profiling_dir, source_path, mindstudio_profiler_output, rank_id, mode):
|
|
40
|
+
def __init__(self, profiling_dir, source_path, mindstudio_profiler_output, rank_id, rank_size, mode):
|
|
41
41
|
super().__init__(DeviceTarget.ASCEND.value, mode)
|
|
42
42
|
self._profiling_dir = profiling_dir
|
|
43
43
|
self._source_path = source_path
|
|
44
44
|
self._mindstudio_profiler_output = mindstudio_profiler_output
|
|
45
45
|
self._rank_id = rank_id
|
|
46
|
+
self._rank_size = rank_size
|
|
46
47
|
self._timeline_display_filename = self._timeline_display_filename.format(rank_id)
|
|
47
48
|
self._timeline_summary_filename = self._timeline_summary_filename.format(rank_id)
|
|
48
49
|
self._timeline_data = []
|
|
@@ -63,7 +64,9 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
|
|
|
63
64
|
"""
|
|
64
65
|
|
|
65
66
|
logger.info('parse cluster data...')
|
|
66
|
-
|
|
67
|
+
if isinstance(op_summary, np.ndarray) and op_summary.shape[0] == 0 or \
|
|
68
|
+
not isinstance(op_summary, np.ndarray) and not op_summary:
|
|
69
|
+
return
|
|
67
70
|
timeline_list = op_summary[~np.isin(op_summary['Task Type'], ['AI_CPU', 'HCCL'])][
|
|
68
71
|
['Op Name', 'Stream ID', 'Task Start Time', 'Task Duration']]
|
|
69
72
|
|
|
@@ -151,6 +154,7 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
|
|
|
151
154
|
# get msprof data
|
|
152
155
|
msprof_file_name = fr'{self._mindstudio_profiler_output}/msprof_*.json'
|
|
153
156
|
file_list_msprof = glob.glob(msprof_file_name)
|
|
157
|
+
msprof_timeline = []
|
|
154
158
|
if not file_list_msprof:
|
|
155
159
|
logger.error('Could not find msprof_*.json file in %s', self._mindstudio_profiler_output)
|
|
156
160
|
else:
|
|
@@ -179,8 +183,7 @@ class AscendTimelineGenerator(BaseTimelineGenerator):
|
|
|
179
183
|
fwk_file_path = fr'{self._profiling_dir}/{self._framework_dir}/{oprange_name}'
|
|
180
184
|
if os.path.exists(fwk_file_path):
|
|
181
185
|
# It is faster not to submit to the pool
|
|
182
|
-
|
|
183
|
-
result = self._parse_fwk_device_data(msprof_side_data)
|
|
186
|
+
result = self._parse_fwk_device_data(msprof_timeline)
|
|
184
187
|
timeline_data.extend(result.get("trace_data", []))
|
|
185
188
|
self._kernel_events = result.get("kernels", [])
|
|
186
189
|
|
|
@@ -92,9 +92,19 @@ class ProfilerInfo:
|
|
|
92
92
|
|
|
93
93
|
@staticmethod
|
|
94
94
|
def set_export_flag(flag):
|
|
95
|
-
"""Set
|
|
95
|
+
"""Set whether all-export or not."""
|
|
96
96
|
ProfilerInfo._profiler_info_dict["all_export"] = flag
|
|
97
97
|
|
|
98
|
+
@staticmethod
|
|
99
|
+
def set_system_time(sys_time):
|
|
100
|
+
"""Set system time."""
|
|
101
|
+
ProfilerInfo._profiler_info_dict["system_time"] = sys_time
|
|
102
|
+
|
|
103
|
+
@staticmethod
|
|
104
|
+
def set_system_cnt(sys_cnt):
|
|
105
|
+
"""Set system cnt."""
|
|
106
|
+
ProfilerInfo._profiler_info_dict["system_cnt"] = sys_cnt
|
|
107
|
+
|
|
98
108
|
@staticmethod
|
|
99
109
|
def set_diff_time(diff_time):
|
|
100
110
|
"""synchronize timestamps between different devices"""
|
mindspore/profiler/profiling.py
CHANGED
|
@@ -662,12 +662,12 @@ class Profiler:
|
|
|
662
662
|
>>> # Profiler init.
|
|
663
663
|
>>> profiler = Profiler()
|
|
664
664
|
>>> # Train Model or eval Model, taking LeNet5 as an example.
|
|
665
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
665
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
666
666
|
>>> net = LeNet5()
|
|
667
667
|
>>> optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
|
668
668
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
669
669
|
>>> # Create the dataset taking MNIST as an example.
|
|
670
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
670
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
671
671
|
>>> dataloader = create_dataset()
|
|
672
672
|
>>> model = Model(net, loss, optimizer)
|
|
673
673
|
>>> model.train(5, dataloader, dataset_sink_mode=False)
|
|
@@ -742,7 +742,6 @@ class Profiler:
|
|
|
742
742
|
stage_num = get_auto_parallel_context("pipeline_stages")
|
|
743
743
|
|
|
744
744
|
ProfilerInfo.set_parallel_info(parallel_mode, stage_num)
|
|
745
|
-
ProfilerInfo.set_rank_size(self._rank_size)
|
|
746
745
|
ProfilerInfo.set_heterogeneous(self._is_heterogeneous)
|
|
747
746
|
if offline_path:
|
|
748
747
|
ProfilerInfo.set_analyse_start_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
|
@@ -842,6 +841,8 @@ class Profiler:
|
|
|
842
841
|
self._md_profiler.start()
|
|
843
842
|
self._ascend_graph_start()
|
|
844
843
|
ProfilerInfo.set_profiling_start_time(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
|
844
|
+
ProfilerInfo.set_system_time(int(c_expression.get_clock_time() * 1e3)) # cast us to ns
|
|
845
|
+
ProfilerInfo.set_system_cnt(c_expression.get_clock_syscnt())
|
|
845
846
|
|
|
846
847
|
def stop(self):
|
|
847
848
|
"""
|
|
@@ -1134,6 +1135,7 @@ class Profiler:
|
|
|
1134
1135
|
self._rank_size = get_group_size()
|
|
1135
1136
|
else:
|
|
1136
1137
|
self._rank_size = int(os.getenv('RANK_SIZE', '1'))
|
|
1138
|
+
ProfilerInfo.set_rank_size(self._rank_size)
|
|
1137
1139
|
|
|
1138
1140
|
if self._has_started:
|
|
1139
1141
|
self.stop()
|
|
@@ -1252,7 +1254,7 @@ class Profiler:
|
|
|
1252
1254
|
try:
|
|
1253
1255
|
logger.info("Profiling: analyzing the timeline data")
|
|
1254
1256
|
timeline_analyser = AscendTimelineGenerator(self._output_path, source_path, mindstudio_profiler_output,
|
|
1255
|
-
self._rank_id, context.get_context('mode'))
|
|
1257
|
+
self._rank_id, self._rank_size, context.get_context('mode'))
|
|
1256
1258
|
timeline_analyser.parse_cluster_data(op_summary, steptrace)
|
|
1257
1259
|
timeline_analyser.parse_timeline_data(pretty=self._pretty_json)
|
|
1258
1260
|
timeline_analyser.write_timeline_display()
|
|
@@ -1439,7 +1441,8 @@ class Profiler:
|
|
|
1439
1441
|
key = name if name.startswith("hcom_") else (name, ts)
|
|
1440
1442
|
launch_op = kernel_map.get(key)
|
|
1441
1443
|
if not launch_op:
|
|
1442
|
-
|
|
1444
|
+
if context.get_context("mode") == context.GRAPH_MODE or not name.startswith("aclnn"):
|
|
1445
|
+
logger.warning(f"Failed to get launch operator for {name}!")
|
|
1443
1446
|
continue
|
|
1444
1447
|
launch_ops[index] = launch_op.name
|
|
1445
1448
|
return launch_ops
|
|
@@ -1467,6 +1470,9 @@ class Profiler:
|
|
|
1467
1470
|
ProfilerInfo.set_export_flag(flag)
|
|
1468
1471
|
op_summary, op_statistic, steptrace, steptrace_model \
|
|
1469
1472
|
= _ascend_graph_msprof_analyse(mindstudio_profiler_output)
|
|
1473
|
+
if isinstance(op_statistic, np.ndarray) and op_statistic.shape[0] == 0 or \
|
|
1474
|
+
not isinstance(op_statistic, np.ndarray) and not op_statistic:
|
|
1475
|
+
return
|
|
1470
1476
|
kernels = self._ascend_timeline_analyse(op_summary, steptrace, source_path, mindstudio_profiler_output)
|
|
1471
1477
|
launch_ops = self._get_kernel_op_map(op_summary, kernels)
|
|
1472
1478
|
self._ascend_op_analyse(op_summary, op_statistic, self._dynamic_status, launch_ops)
|
|
@@ -1505,6 +1511,8 @@ class Profiler:
|
|
|
1505
1511
|
else:
|
|
1506
1512
|
self._rank_size = int(os.getenv('RANK_SIZE', '1'))
|
|
1507
1513
|
|
|
1514
|
+
ProfilerInfo.set_rank_size(self._rank_size)
|
|
1515
|
+
|
|
1508
1516
|
if self._has_started:
|
|
1509
1517
|
self.stop()
|
|
1510
1518
|
else:
|
mindspore/rewrite/api/node.py
CHANGED
|
@@ -89,7 +89,7 @@ class Node:
|
|
|
89
89
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
90
90
|
>>> import mindspore.nn as nn
|
|
91
91
|
>>> # Define the network structure of LeNet5. Refer to
|
|
92
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
92
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
93
93
|
>>> net = LeNet5()
|
|
94
94
|
>>> stree = SymbolTree.create(net)
|
|
95
95
|
>>> node = stree.get_node("conv1")
|
|
@@ -144,7 +144,7 @@ class Node:
|
|
|
144
144
|
>>> import mindspore.nn as nn
|
|
145
145
|
>>> import mindspore.ops as ops
|
|
146
146
|
>>> # Define the network structure of LeNet5. Refer to
|
|
147
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
147
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
148
148
|
>>> net = LeNet5()
|
|
149
149
|
>>> stree = SymbolTree.create(net)
|
|
150
150
|
>>> node = stree.get_node("conv1")
|
|
@@ -184,7 +184,7 @@ class Node:
|
|
|
184
184
|
Examples:
|
|
185
185
|
>>> from mindspore.rewrite import SymbolTree
|
|
186
186
|
>>> # Define the network structure of LeNet5. Refer to
|
|
187
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
187
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
188
188
|
>>> net = LeNet5()
|
|
189
189
|
>>> stree = SymbolTree.create(net)
|
|
190
190
|
>>> node = stree.get_node("conv2")
|
|
@@ -204,7 +204,7 @@ class Node:
|
|
|
204
204
|
Examples:
|
|
205
205
|
>>> from mindspore.rewrite import SymbolTree
|
|
206
206
|
>>> # Define the network structure of LeNet5. Refer to
|
|
207
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
207
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
208
208
|
>>> net = LeNet5()
|
|
209
209
|
>>> stree = SymbolTree.create(net)
|
|
210
210
|
>>> node = stree.get_node("conv1")
|
|
@@ -229,7 +229,7 @@ class Node:
|
|
|
229
229
|
Examples:
|
|
230
230
|
>>> from mindspore.rewrite import SymbolTree
|
|
231
231
|
>>> # Define the network structure of LeNet5. Refer to
|
|
232
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
232
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
233
233
|
>>> net = LeNet5()
|
|
234
234
|
>>> stree = SymbolTree.create(net)
|
|
235
235
|
>>> node = stree.get_node("relu_3")
|
|
@@ -267,7 +267,7 @@ class Node:
|
|
|
267
267
|
Examples:
|
|
268
268
|
>>> from mindspore.rewrite import SymbolTree
|
|
269
269
|
>>> # Define the network structure of LeNet5. Refer to
|
|
270
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
270
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
271
271
|
>>> net = LeNet5()
|
|
272
272
|
>>> stree = SymbolTree.create(net)
|
|
273
273
|
>>> src_node = stree.get_node("fc1")
|
|
@@ -307,7 +307,7 @@ class Node:
|
|
|
307
307
|
Examples:
|
|
308
308
|
>>> from mindspore.rewrite import SymbolTree
|
|
309
309
|
>>> # Define the network structure of LeNet5. Refer to
|
|
310
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
310
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
311
311
|
>>> net = LeNet5()
|
|
312
312
|
>>> stree = SymbolTree.create(net)
|
|
313
313
|
>>> node = stree.get_node("conv1")
|
|
@@ -327,7 +327,7 @@ class Node:
|
|
|
327
327
|
Examples:
|
|
328
328
|
>>> from mindspore.rewrite import SymbolTree
|
|
329
329
|
>>> # Define the network structure of LeNet5. Refer to
|
|
330
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
330
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
331
331
|
>>> net = LeNet5()
|
|
332
332
|
>>> stree = SymbolTree.create(net)
|
|
333
333
|
>>> node = stree.get_node("conv1")
|
|
@@ -354,7 +354,7 @@ class Node:
|
|
|
354
354
|
Examples:
|
|
355
355
|
>>> from mindspore.rewrite import SymbolTree
|
|
356
356
|
>>> # Define the network structure of LeNet5. Refer to
|
|
357
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
357
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
358
358
|
>>> net = LeNet5()
|
|
359
359
|
>>> stree = SymbolTree.create(net)
|
|
360
360
|
>>> node = stree.get_node("conv1")
|
|
@@ -377,7 +377,7 @@ class Node:
|
|
|
377
377
|
Examples:
|
|
378
378
|
>>> from mindspore.rewrite import SymbolTree
|
|
379
379
|
>>> # Define the network structure of LeNet5. Refer to
|
|
380
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
380
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
381
381
|
>>> net = LeNet5()
|
|
382
382
|
>>> stree = SymbolTree.create(net)
|
|
383
383
|
>>> node = stree.get_node("conv1")
|
|
@@ -396,7 +396,7 @@ class Node:
|
|
|
396
396
|
Examples:
|
|
397
397
|
>>> from mindspore.rewrite import SymbolTree
|
|
398
398
|
>>> # Define the network structure of LeNet5. Refer to
|
|
399
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
399
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
400
400
|
>>> net = LeNet5()
|
|
401
401
|
>>> stree = SymbolTree.create(net)
|
|
402
402
|
>>> node = stree.get_node("conv1")
|
|
@@ -434,7 +434,7 @@ class Node:
|
|
|
434
434
|
... x = self.relu(x)
|
|
435
435
|
... return x
|
|
436
436
|
...
|
|
437
|
-
|
|
437
|
+
>>> class Net(nn.Cell):
|
|
438
438
|
... def __init__(self):
|
|
439
439
|
... super().__init__()
|
|
440
440
|
... self.subnet = SubNet()
|
|
@@ -119,7 +119,7 @@ class SymbolTree:
|
|
|
119
119
|
Examples:
|
|
120
120
|
>>> from mindspore.rewrite import SymbolTree
|
|
121
121
|
>>> # Define the network structure of LeNet5. Refer to
|
|
122
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
122
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
123
123
|
>>> net = LeNet5()
|
|
124
124
|
>>> stree = SymbolTree.create(net)
|
|
125
125
|
>>> print(type(stree))
|
|
@@ -163,7 +163,7 @@ class SymbolTree:
|
|
|
163
163
|
Examples:
|
|
164
164
|
>>> from mindspore.rewrite import SymbolTree
|
|
165
165
|
>>> # Define the network structure of LeNet5. Refer to
|
|
166
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
166
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
167
167
|
>>> net = LeNet5()
|
|
168
168
|
>>> stree = SymbolTree.create(net)
|
|
169
169
|
>>> print([node.get_name() for node in stree.nodes()])
|
|
@@ -188,7 +188,7 @@ class SymbolTree:
|
|
|
188
188
|
Examples:
|
|
189
189
|
>>> from mindspore.rewrite import SymbolTree
|
|
190
190
|
>>> # Define the network structure of LeNet5. Refer to
|
|
191
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
191
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
192
192
|
>>> net = LeNet5()
|
|
193
193
|
>>> stree = SymbolTree.create(net)
|
|
194
194
|
>>> node = stree.get_node('conv1')
|
|
@@ -221,7 +221,7 @@ class SymbolTree:
|
|
|
221
221
|
Examples:
|
|
222
222
|
>>> from mindspore.rewrite import SymbolTree
|
|
223
223
|
>>> # Define the network structure of LeNet5. Refer to
|
|
224
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
224
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
225
225
|
>>> net = LeNet5()
|
|
226
226
|
>>> stree = SymbolTree.create(net)
|
|
227
227
|
>>> for node in stree.nodes():
|
|
@@ -250,7 +250,7 @@ class SymbolTree:
|
|
|
250
250
|
Examples:
|
|
251
251
|
>>> from mindspore.rewrite import SymbolTree
|
|
252
252
|
>>> # Define the network structure of LeNet5. Refer to
|
|
253
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
253
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
254
254
|
>>> net = LeNet5()
|
|
255
255
|
>>> stree = SymbolTree.create(net)
|
|
256
256
|
>>> for node in stree.nodes():
|
|
@@ -284,7 +284,7 @@ class SymbolTree:
|
|
|
284
284
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
285
285
|
>>> import mindspore.nn as nn
|
|
286
286
|
>>> # Define the network structure of LeNet5. Refer to
|
|
287
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
287
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
288
288
|
>>> net = LeNet5()
|
|
289
289
|
>>> stree = SymbolTree.create(net)
|
|
290
290
|
>>> node = stree.get_node("conv1")
|
|
@@ -313,7 +313,7 @@ class SymbolTree:
|
|
|
313
313
|
Examples:
|
|
314
314
|
>>> from mindspore.rewrite import SymbolTree
|
|
315
315
|
>>> # Define the network structure of LeNet5. Refer to
|
|
316
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
316
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
317
317
|
>>> net = LeNet5()
|
|
318
318
|
>>> stree = SymbolTree.create(net)
|
|
319
319
|
>>> node = stree.get_node("conv1")
|
|
@@ -351,7 +351,7 @@ class SymbolTree:
|
|
|
351
351
|
>>> from mindspore.rewrite import SymbolTree, ScopedValue
|
|
352
352
|
>>> import mindspore.nn as nn
|
|
353
353
|
>>> # Define the network structure of LeNet5. Refer to
|
|
354
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
354
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
355
355
|
>>> net = LeNet5()
|
|
356
356
|
>>> stree = SymbolTree.create(net)
|
|
357
357
|
>>> node = stree.get_node("conv1")
|
|
@@ -397,7 +397,7 @@ class SymbolTree:
|
|
|
397
397
|
Examples:
|
|
398
398
|
>>> from mindspore.rewrite import SymbolTree
|
|
399
399
|
>>> # Define the network structure of LeNet5. Refer to
|
|
400
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
400
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
401
401
|
>>> net = LeNet5()
|
|
402
402
|
>>> stree = SymbolTree.create(net)
|
|
403
403
|
>>> stree.print_node_tabulate()
|
|
@@ -417,7 +417,7 @@ class SymbolTree:
|
|
|
417
417
|
Examples:
|
|
418
418
|
>>> from mindspore.rewrite import SymbolTree
|
|
419
419
|
>>> # Define the network structure of LeNet5. Refer to
|
|
420
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
420
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
421
421
|
>>> net = LeNet5()
|
|
422
422
|
>>> stree = SymbolTree.create(net)
|
|
423
423
|
>>> codes = stree.get_code()
|
|
@@ -444,7 +444,7 @@ class SymbolTree:
|
|
|
444
444
|
Examples:
|
|
445
445
|
>>> from mindspore.rewrite import SymbolTree
|
|
446
446
|
>>> # Define the network structure of LeNet5. Refer to
|
|
447
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
447
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
448
448
|
>>> net = LeNet5()
|
|
449
449
|
>>> stree = SymbolTree.create(net)
|
|
450
450
|
>>> new_net = stree.get_network()
|
|
@@ -527,7 +527,7 @@ def check_version_and_env_config():
|
|
|
527
527
|
except OSError:
|
|
528
528
|
logger.warning("Pre-Load Library libgomp.so.1 failed, which might cause TLS memory allocation failure. If "
|
|
529
529
|
"the failure occurs, please refer to the FAQ for a solution: "
|
|
530
|
-
"https://www.mindspore.cn/docs/en/
|
|
530
|
+
"https://www.mindspore.cn/docs/en/master/faq/installation.html.")
|
|
531
531
|
MSContext.get_instance().register_check_env_callback(check_env)
|
|
532
532
|
MSContext.get_instance().register_set_env_callback(set_env)
|
|
533
533
|
MSContext.get_instance().set_device_target_inner(MSContext.get_instance().get_param(ms_ctx_param.device_target))
|
|
@@ -73,7 +73,7 @@ def obfuscate_ckpt(network, ckpt_files, target_modules=None, saved_path='./', ob
|
|
|
73
73
|
|
|
74
74
|
Examples:
|
|
75
75
|
>>> from mindspore import obfuscate_ckpt, save_checkpoint
|
|
76
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
76
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
77
77
|
>>> net = LeNet5()
|
|
78
78
|
>>> save_checkpoint(net, './test_net.ckpt')
|
|
79
79
|
>>> target_modules = ['', 'fc1|fc2']
|
|
@@ -204,7 +204,7 @@ def load_obf_params_into_net(network, target_modules, obf_ratios, data_parallel_
|
|
|
204
204
|
>>> from mindspore import obfuscate_ckpt, save_checkpoint, load_checkpoint, Tensor
|
|
205
205
|
>>> import mindspore.common.dtype as mstype
|
|
206
206
|
>>> import numpy as np
|
|
207
|
-
>>> # Refer to https://gitee.com/mindspore/docs/blob/
|
|
207
|
+
>>> # Refer to https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
208
208
|
>>> net = LeNet5()
|
|
209
209
|
>>> save_checkpoint(net, './test_net.ckpt')
|
|
210
210
|
>>> target_modules = ['', 'fc1|fc2']
|
mindspore/train/amp.py
CHANGED
|
@@ -331,7 +331,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
331
331
|
:class:`mindspore.nn.LayerNorm`]
|
|
332
332
|
|
|
333
333
|
For details on automatic mixed precision, refer to
|
|
334
|
-
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/
|
|
334
|
+
`Automatic Mix Precision <https://www.mindspore.cn/tutorials/en/master/advanced/mixed_precision.html>`_ .
|
|
335
335
|
|
|
336
336
|
Note:
|
|
337
337
|
- Repeatedly calling mixed-precision interfaces, such as `custom_mixed_precision` and `auto_mixed_precision`,
|
|
@@ -362,7 +362,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=mstype.float16):
|
|
|
362
362
|
Examples:
|
|
363
363
|
>>> from mindspore import amp
|
|
364
364
|
>>> # Define the network structure of LeNet5. Refer to
|
|
365
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
365
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
366
366
|
>>> network = LeNet5()
|
|
367
367
|
>>> amp_level = "O1"
|
|
368
368
|
>>> net = amp.auto_mixed_precision(network, amp_level)
|
|
@@ -597,7 +597,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
|
|
|
597
597
|
Examples:
|
|
598
598
|
>>> from mindspore import amp, nn
|
|
599
599
|
>>> # Define the network structure of LeNet5. Refer to
|
|
600
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
600
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
601
601
|
>>> network = LeNet5()
|
|
602
602
|
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(reduction="mean")
|
|
603
603
|
>>> net_opt = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
|
|
@@ -744,7 +744,7 @@ def custom_mixed_precision(network, *, white_list=None, black_list=None, dtype=m
|
|
|
744
744
|
Examples:
|
|
745
745
|
>>> from mindspore import amp, nn
|
|
746
746
|
>>> # Define the network structure of LeNet5. Refer to
|
|
747
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
747
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
748
748
|
>>> net = LeNet5()
|
|
749
749
|
>>> custom_white_list = amp.get_white_list()
|
|
750
750
|
>>> custom_white_list.append(nn.Flatten)
|
mindspore/train/anf_ir_pb2.py
CHANGED
|
@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
|
|
|
20
20
|
syntax='proto2',
|
|
21
21
|
serialized_options=None,
|
|
22
22
|
create_key=_descriptor._internal_create_key,
|
|
23
|
-
serialized_pb=b'\n\x0c\x61nf_ir.proto\x12\x0emindspore.irpb\"\xdb\x04\n\nValueProto\x12\'\n\x05\x64type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x10\n\x08\x62ool_val\x18\x02 \x01(\x08\x12\x0f\n\x07int_val\x18\x03 \x01(\x03\x12\x10\n\x08uint_val\x18\x04 \x01(\x04\x12\x11\n\tfloat_val\x18\x05 \x01(\x02\x12\x12\n\ndouble_val\x18\x06 \x01(\x01\x12\x0f\n\x07str_val\x18\x07 \x01(\t\x12/\n\ntensor_val\x18\x08 \x01(\x0b\x32\x1b.mindspore.irpb.TensorProto\x12)\n\x05graph\x18\t \x01(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12\x11\n\tbool_vals\x18\n \x03(\x08\x12\x10\n\x08int_vals\x18\x0b \x03(\x03\x12\x11\n\tuint_vals\x18\x0c \x03(\x04\x12\x12\n\nfloat_vals\x18\r \x03(\x02\x12\x13\n\x0b\x64ouble_vals\x18\x0e \x03(\x01\x12\x10\n\x08str_vals\x18\x0f \x03(\t\x12\x30\n\x0btensor_vals\x18\x10 \x03(\x0b\x32\x1b.mindspore.irpb.TensorProto\x12*\n\x06graphs\x18\x11 \x03(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12*\n\x06values\x18\x12 \x03(\x0b\x32\x1a.mindspore.irpb.ValueProto\x12\x31\n\x08\x64ict_val\x18\x13 \x03(\x0b\x32\x1f.mindspore.irpb.NamedValueProto\x12+\n\x08type_val\x18\x14 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\"I\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"I\n\x0fNamedValueProto\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"t\n\x10TensorShapeProto\x12\x37\n\x03\x64im\x18\x01 \x03(\x0b\x32*.mindspore.irpb.TensorShapeProto.Dimension\x1a\'\n\tDimension\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"\xda\x02\n\tTypeProto\x12+\n\tdata_type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x37\n\x0btensor_type\x18\x02 \x01(\x0b\x32 .mindspore.irpb.TypeProto.TensorH\x00\x12;\n\rsequence_type\x18\x03 \x01(\x0b\x32\".mindspore.irpb.TypeProto.SequenceH\x00\x1a\x66\n\x06Tensor\x12+\n\telem_type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12/\n\x05shape\x18\x02 \x01(\x0b\x32 .mindspore.irpb.TensorShapeProto\x1a\x39\n\x08Sequence\x12-\n\nelem_types\x18\x01 \x03(\x0b\x32\x19.mindspore.irpb.TypeProtoB\x07\n\x05value\"x\n\x0eParameterProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\'\n\x04type\x18\x02 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\x12/\n\x0b\x64\x65\x66\x61ult_val\x18\x03 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"D\n\x0bOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\'\n\x04type\x18\x02 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\"z\n\nInputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04type\x18\x02 \x01(\x0e\x32#.mindspore.irpb.InputProto.EdgeType\"+\n\x08\x45\x64geType\x12\r\n\tDATA_EDGE\x10\x00\x12\x10\n\x0c\x43ONTROL_EDGE\x10\x01\"\x83\x02\n\tNodeProto\x12)\n\x05input\x18\x01 \x03(\x0b\x32\x1a.mindspore.irpb.InputProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07op_type\x18\x03 \x01(\t\x12\r\n\x05scope\x18\x04 \x01(\t\x12\x31\n\tattribute\x18\x05 \x03(\x0b\x32\x1e.mindspore.irpb.AttributeProto\x12.\n\x0boutput_type\x18\x06 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\x12\x10\n\x08output_i\x18\x07 \x01(\x04\x12\x11\n\tfull_name\x18\x08 \x01(\t\x12\x15\n\rinstance_name\x18\n \x01(\t\"\xb0\x01\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\x03\x12\x0e\n\x06\x64omain\x18\x02 \x01(\t\x12\x15\n\rmodel_version\x18\x03 \x01(\x03\x12)\n\x05graph\x18\x04 \x01(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12<\n\x12metadata_operators\x18\x05 \x01(\x0b\x32 .mindspore.irpb.OperatorSetProto\"?\n\rOperatorProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x02 \x01(\x0c\x12\x10\n\x08obj_info\x18\x03 \x01(\x0c\"U\n\x10OperatorSetProto\x12\x30\n\toperators\x18\x01 \x03(\x0b\x32\x1d.mindspore.irpb.OperatorProto\x12\x0f\n\x07summary\x18\x02 \x01(\t\"\xda\x01\n\nGraphProto\x12\'\n\x04node\x18\x01 \x03(\x0b\x32\x19.mindspore.irpb.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x32\n\nparameters\x18\x03 \x03(\x0b\x32\x1e.mindspore.irpb.ParameterProto\x12,\n\x07outputs\x18\x04 \x03(\x0b\x32\x1b.mindspore.irpb.OutputProto\x12\x33\n\nconst_vals\x18\x05 \x03(\x0b\x32\x1f.mindspore.irpb.NamedValueProto\"\xd4\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12+\n\tdata_type\x18\x02 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x16\n\nint64_data\x18\x05 \x03(\x03\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_data\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x17\n\x0buint64_data\x18\x07 \x03(\x04\x42\x02\x10\x01\x12\x10\n\x08raw_data\x18\x08 \x01(\x0c*/\n\x07Version\x12\x14\n\x10UNKNOWWN_VERSION\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01*\
|
|
23
|
+
serialized_pb=b'\n\x0c\x61nf_ir.proto\x12\x0emindspore.irpb\"\xdb\x04\n\nValueProto\x12\'\n\x05\x64type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x10\n\x08\x62ool_val\x18\x02 \x01(\x08\x12\x0f\n\x07int_val\x18\x03 \x01(\x03\x12\x10\n\x08uint_val\x18\x04 \x01(\x04\x12\x11\n\tfloat_val\x18\x05 \x01(\x02\x12\x12\n\ndouble_val\x18\x06 \x01(\x01\x12\x0f\n\x07str_val\x18\x07 \x01(\t\x12/\n\ntensor_val\x18\x08 \x01(\x0b\x32\x1b.mindspore.irpb.TensorProto\x12)\n\x05graph\x18\t \x01(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12\x11\n\tbool_vals\x18\n \x03(\x08\x12\x10\n\x08int_vals\x18\x0b \x03(\x03\x12\x11\n\tuint_vals\x18\x0c \x03(\x04\x12\x12\n\nfloat_vals\x18\r \x03(\x02\x12\x13\n\x0b\x64ouble_vals\x18\x0e \x03(\x01\x12\x10\n\x08str_vals\x18\x0f \x03(\t\x12\x30\n\x0btensor_vals\x18\x10 \x03(\x0b\x32\x1b.mindspore.irpb.TensorProto\x12*\n\x06graphs\x18\x11 \x03(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12*\n\x06values\x18\x12 \x03(\x0b\x32\x1a.mindspore.irpb.ValueProto\x12\x31\n\x08\x64ict_val\x18\x13 \x03(\x0b\x32\x1f.mindspore.irpb.NamedValueProto\x12+\n\x08type_val\x18\x14 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\"I\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"I\n\x0fNamedValueProto\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"t\n\x10TensorShapeProto\x12\x37\n\x03\x64im\x18\x01 \x03(\x0b\x32*.mindspore.irpb.TensorShapeProto.Dimension\x1a\'\n\tDimension\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"\xda\x02\n\tTypeProto\x12+\n\tdata_type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x37\n\x0btensor_type\x18\x02 \x01(\x0b\x32 .mindspore.irpb.TypeProto.TensorH\x00\x12;\n\rsequence_type\x18\x03 \x01(\x0b\x32\".mindspore.irpb.TypeProto.SequenceH\x00\x1a\x66\n\x06Tensor\x12+\n\telem_type\x18\x01 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12/\n\x05shape\x18\x02 \x01(\x0b\x32 .mindspore.irpb.TensorShapeProto\x1a\x39\n\x08Sequence\x12-\n\nelem_types\x18\x01 \x03(\x0b\x32\x19.mindspore.irpb.TypeProtoB\x07\n\x05value\"x\n\x0eParameterProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\'\n\x04type\x18\x02 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\x12/\n\x0b\x64\x65\x66\x61ult_val\x18\x03 \x01(\x0b\x32\x1a.mindspore.irpb.ValueProto\"D\n\x0bOutputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\'\n\x04type\x18\x02 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\"z\n\nInputProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x31\n\x04type\x18\x02 \x01(\x0e\x32#.mindspore.irpb.InputProto.EdgeType\"+\n\x08\x45\x64geType\x12\r\n\tDATA_EDGE\x10\x00\x12\x10\n\x0c\x43ONTROL_EDGE\x10\x01\"\x83\x02\n\tNodeProto\x12)\n\x05input\x18\x01 \x03(\x0b\x32\x1a.mindspore.irpb.InputProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0f\n\x07op_type\x18\x03 \x01(\t\x12\r\n\x05scope\x18\x04 \x01(\t\x12\x31\n\tattribute\x18\x05 \x03(\x0b\x32\x1e.mindspore.irpb.AttributeProto\x12.\n\x0boutput_type\x18\x06 \x01(\x0b\x32\x19.mindspore.irpb.TypeProto\x12\x10\n\x08output_i\x18\x07 \x01(\x04\x12\x11\n\tfull_name\x18\x08 \x01(\t\x12\x15\n\rinstance_name\x18\n \x01(\t\"\xb0\x01\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\x03\x12\x0e\n\x06\x64omain\x18\x02 \x01(\t\x12\x15\n\rmodel_version\x18\x03 \x01(\x03\x12)\n\x05graph\x18\x04 \x01(\x0b\x32\x1a.mindspore.irpb.GraphProto\x12<\n\x12metadata_operators\x18\x05 \x01(\x0b\x32 .mindspore.irpb.OperatorSetProto\"?\n\rOperatorProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x02 \x01(\x0c\x12\x10\n\x08obj_info\x18\x03 \x01(\x0c\"U\n\x10OperatorSetProto\x12\x30\n\toperators\x18\x01 \x03(\x0b\x32\x1d.mindspore.irpb.OperatorProto\x12\x0f\n\x07summary\x18\x02 \x01(\t\"\xda\x01\n\nGraphProto\x12\'\n\x04node\x18\x01 \x03(\x0b\x32\x19.mindspore.irpb.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x32\n\nparameters\x18\x03 \x03(\x0b\x32\x1e.mindspore.irpb.ParameterProto\x12,\n\x07outputs\x18\x04 \x03(\x0b\x32\x1b.mindspore.irpb.OutputProto\x12\x33\n\nconst_vals\x18\x05 \x03(\x0b\x32\x1f.mindspore.irpb.NamedValueProto\"\xd4\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12+\n\tdata_type\x18\x02 \x01(\x0e\x32\x18.mindspore.irpb.DataType\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x16\n\nint64_data\x18\x05 \x03(\x03\x42\x02\x10\x01\x12\x17\n\x0b\x64ouble_data\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x17\n\x0buint64_data\x18\x07 \x03(\x04\x42\x02\x10\x01\x12\x10\n\x08raw_data\x18\x08 \x01(\x0c*/\n\x07Version\x12\x14\n\x10UNKNOWWN_VERSION\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01*\xfb\x05\n\x08\x44\x61taType\x12\x10\n\x0c\x44T_UNDEFINED\x10\x00\x12\x0b\n\x07\x44T_BOOL\x10\x01\x12\x0b\n\x07\x44T_INT8\x10\x02\x12\x0c\n\x08\x44T_INT16\x10\x03\x12\x0c\n\x08\x44T_INT32\x10\x04\x12\x0c\n\x08\x44T_INT64\x10\x05\x12\x0c\n\x08\x44T_UINT8\x10\x06\x12\r\n\tDT_UINT16\x10\x07\x12\r\n\tDT_UINT32\x10\x08\x12\r\n\tDT_UINT64\x10\t\x12\x0e\n\nDT_FLOAT16\x10\n\x12\x0e\n\nDT_FLOAT32\x10\x0b\x12\x0e\n\nDT_FLOAT64\x10\x0c\x12\r\n\tDT_STRING\x10\r\x12\r\n\tDT_TENSOR\x10\x0e\x12\x0c\n\x08\x44T_GRAPH\x10\x0f\x12\x0c\n\x08\x44T_BOOLS\x10\x10\x12\x0c\n\x08\x44T_INTS8\x10\x11\x12\r\n\tDT_INTS16\x10\x12\x12\r\n\tDT_INTS32\x10\x13\x12\r\n\tDT_INTS64\x10\x14\x12\r\n\tDT_UINTS8\x10\x15\x12\x0e\n\nDT_UINTS16\x10\x16\x12\x0e\n\nDT_UINTS32\x10\x17\x12\x0e\n\nDT_UINTS64\x10\x18\x12\x0f\n\x0b\x44T_FLOATS16\x10\x19\x12\x0f\n\x0b\x44T_FLOATS32\x10\x1a\x12\x0f\n\x0b\x44T_FLOATS64\x10\x1b\x12\x0e\n\nDT_STRINGS\x10\x1c\x12\x0e\n\nDT_TENSORS\x10\x1d\x12\r\n\tDT_GRAPHS\x10\x1e\x12\x0c\n\x08\x44T_TUPLE\x10\x1f\x12\x0b\n\x07\x44T_LIST\x10 \x12\x0b\n\x07\x44T_DICT\x10!\x12\x0b\n\x07\x44T_NONE\x10\"\x12\x0f\n\x0b\x44T_SYM_INST\x10#\x12\x0f\n\x0b\x44T_BASE_INT\x10$\x12\x10\n\x0c\x44T_BASE_UINT\x10%\x12\x11\n\rDT_BASE_FLOAT\x10&\x12\x0b\n\x07\x44T_TYPE\x10\'\x12\n\n\x06\x44T_ANY\x10(\x12\r\n\tDT_REFKEY\x10)\x12\n\n\x06\x44T_REF\x10*\x12\x10\n\x0c\x44T_COMPLEX64\x10+\x12\x11\n\rDT_COMPLEX128\x10,\x12\x13\n\x0f\x44T_BASE_COMPLEX\x10-\x12\x0f\n\x0b\x44T_BFLOAT16\x10.\x12\x10\n\x0c\x44T_BFLOATS16\x10/\x12\x0b\n\x07\x44T_INT4\x10\x30'
|
|
24
24
|
)
|
|
25
25
|
|
|
26
26
|
_VERSION = _descriptor.EnumDescriptor(
|
|
@@ -296,11 +296,16 @@ _DATATYPE = _descriptor.EnumDescriptor(
|
|
|
296
296
|
serialized_options=None,
|
|
297
297
|
type=None,
|
|
298
298
|
create_key=_descriptor._internal_create_key),
|
|
299
|
+
_descriptor.EnumValueDescriptor(
|
|
300
|
+
name='DT_INT4', index=48, number=48,
|
|
301
|
+
serialized_options=None,
|
|
302
|
+
type=None,
|
|
303
|
+
create_key=_descriptor._internal_create_key),
|
|
299
304
|
],
|
|
300
305
|
containing_type=None,
|
|
301
306
|
serialized_options=None,
|
|
302
307
|
serialized_start=2650,
|
|
303
|
-
serialized_end=
|
|
308
|
+
serialized_end=3413,
|
|
304
309
|
)
|
|
305
310
|
_sym_db.RegisterEnumDescriptor(_DATATYPE)
|
|
306
311
|
|
|
@@ -355,6 +360,7 @@ DT_COMPLEX128 = 44
|
|
|
355
360
|
DT_BASE_COMPLEX = 45
|
|
356
361
|
DT_BFLOAT16 = 46
|
|
357
362
|
DT_BFLOATS16 = 47
|
|
363
|
+
DT_INT4 = 48
|
|
358
364
|
|
|
359
365
|
|
|
360
366
|
_INPUTPROTO_EDGETYPE = _descriptor.EnumDescriptor(
|
|
@@ -50,13 +50,13 @@ class BackupAndRestore(Callback):
|
|
|
50
50
|
>>> from mindspore.train import Model, BackupAndRestore, RunContext
|
|
51
51
|
>>>
|
|
52
52
|
>>> # Define the network structure of LeNet5. Refer to
|
|
53
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
53
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
|
|
54
54
|
>>> net = LeNet5()
|
|
55
55
|
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
|
56
56
|
>>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
|
|
57
57
|
>>> model = Model(net, loss_fn=loss, optimizer=optim)
|
|
58
58
|
>>> # Create the dataset taking MNIST as an example. Refer to
|
|
59
|
-
>>> # https://gitee.com/mindspore/docs/blob/
|
|
59
|
+
>>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
|
|
60
60
|
>>> dataset = create_dataset()
|
|
61
61
|
>>> backup_ckpt = BackupAndRestore("backup")
|
|
62
62
|
>>> model.train(10, dataset, callbacks=backup_ckpt)
|
|
@@ -123,7 +123,7 @@ class Callback:
|
|
|
123
123
|
recording current attributes. Users can add custimized attributes to the information.
|
|
124
124
|
Training process can also be stopped by calling `request_stop` method. For details
|
|
125
125
|
of custom Callback, please check
|
|
126
|
-
`Callback tutorial <https://www.mindspore.cn/tutorials/en/
|
|
126
|
+
`Callback tutorial <https://www.mindspore.cn/tutorials/en/master/advanced/model/
|
|
127
127
|
callback.html#customized-callback-mechanism>`_.
|
|
128
128
|
|
|
129
129
|
Examples:
|
|
@@ -493,7 +493,7 @@ class RunContext:
|
|
|
493
493
|
`RunContext.original_args()` and add extra attributes to the information, but also can stop the
|
|
494
494
|
training process by calling `request_stop` method. For details of custom Callback,
|
|
495
495
|
please check
|
|
496
|
-
`Callback Mechanism <https://www.mindspore.cn/tutorials/en/
|
|
496
|
+
`Callback Mechanism <https://www.mindspore.cn/tutorials/en/master/advanced/model/callback.html>`_.
|
|
497
497
|
|
|
498
498
|
`RunContext.original_args()` holds the model context information as a dictionary variable, and
|
|
499
499
|
different attributes of the dictionary are stored in training or eval process. Details are as follows:
|
|
@@ -575,7 +575,7 @@ class RunContext:
|
|
|
575
575
|
|
|
576
576
|
Tutorial Examples:
|
|
577
577
|
- `Callback Mechanism - Customized Callback Mechanism
|
|
578
|
-
<https://mindspore.cn/tutorials/en/
|
|
578
|
+
<https://mindspore.cn/tutorials/en/master/advanced/model/callback.html#customized-callback-mechanism>`_
|
|
579
579
|
"""
|
|
580
580
|
return self._original_args
|
|
581
581
|
|
|
@@ -588,7 +588,7 @@ class RunContext:
|
|
|
588
588
|
|
|
589
589
|
Tutorial Examples:
|
|
590
590
|
- `Callback Mechanism - Customized Training Termination Time
|
|
591
|
-
<https://mindspore.cn/tutorials/en/
|
|
591
|
+
<https://mindspore.cn/tutorials/en/master/advanced/model/callback.html#
|
|
592
592
|
customized-training-termination-time>`_
|
|
593
593
|
"""
|
|
594
594
|
self._stop_requested = True
|