mindspore 2.7.0-cp310-cp310-win_amd64.whl → 2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore has been flagged as possibly problematic by the registry.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
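The mindspore/common/api.py diff below expands the mindspore.jit docstring for the backend argument (GE is Ascend-only, graph mode, no dynamic shape) and relaxes validation so that an empty string keeps the framework's default choice. For orientation, a minimal usage sketch, assuming the decorator accepts backend exactly as that docstring shows:

import numpy as np
import mindspore as ms
from mindspore import Tensor

# backend="" (the default) lets MindSpore choose; only "ms_backend" and "GE" pass validation.
# "ms_backend" runs on Ascend/GPU/CPU; "GE" targets Ascend graph mode with static shapes only.
@ms.jit(backend="ms_backend")
def add(x, y):
    return x + y

out = add(Tensor(np.ones((2, 2), np.float32)), Tensor(np.ones((2, 2), np.float32)))
print(out)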
mindspore/common/api.py
CHANGED
@@ -44,7 +44,7 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx,
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, TensorPy as Tensor, dump_func_graph, _GraphFragment_
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
     _is_parallel_mode
@@ -208,6 +208,11 @@ def _handle_func_args(func, *args, **kwargs):
     args = bound_arguments.args
     kwargs = bound_arguments.kwargs

+    return args, kwargs
+
+
+def _check_func_args(func, *args):
+    """Check the *args inputs of the function"""
     positional_args = 0
     default_args = 0
     has_var = False
@@ -221,14 +226,13 @@ def _handle_func_args(func, *args, **kwargs):
                 default_args += 1

     if has_var:
-        return
+        return

     if len(args) < positional_args:
         raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument, but got {len(args)}.")
     if len(args) > positional_args + default_args:
         raise TypeError(f"Function {func.__name__} needs {positional_args} positional argument and {default_args} "
                         f"default argument, total {positional_args + default_args}, but got {len(args)}.")
-    return args, kwargs


 sys_path = list(sys.path)
@@ -349,7 +353,7 @@ def _get_parameter_layout():
     return layout


-def _handle_arg(obj, arg, has_mutable_arg):
+def _handle_arg(obj, arg, has_mutable_arg, is_predict):
     """Handle arg for runtime .If need handle the arg, return True"""
     from mindspore._extends.parse import compile_config
     if isinstance(arg, PythonTensor):
@@ -364,7 +368,7 @@ def _handle_arg(obj, arg, has_mutable_arg):
         if isinstance(arg, list) and not arg:
             return None
         return arg
-    elif (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
+    elif not is_predict and (context.get_context("grad_for_scalar") or str(compile_config.GRAD_FOR_SCALAR) == '1') and \
             isinstance(arg, (int, float)):
         return arg
     elif hasattr(obj, "enable_tuple_broaden") and obj.enable_tuple_broaden and isinstance(arg, tuple) and \
@@ -394,17 +398,16 @@ def _handle_arg_predict(obj, arg, has_mutable_arg):
     return arg


-def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
+def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict=False):
     """Get the actual input args and kwargs for runtime."""
     new_args = []
-    fn = _handle_arg_predict if is_predict else _handle_arg
     for arg, has_mutable_arg in zip(args, has_mutable_args_list):
-        new_arg =
+        new_arg = _handle_arg(obj, arg, has_mutable_arg, is_predict)
         if new_arg is not None:
             new_args.append(new_arg)

     for _, value in kwargs.items():
-        new_value =
+        new_value = _handle_arg(obj, value, None, is_predict)
         if new_value is not None:
             new_args.append(new_value)

@@ -609,7 +612,7 @@ class _JitExecutor:
         else:
             self._graph_executor = GraphExecutor_.get_instance()
         self._create_time = ms_create_time
-        self.
+        self._mutable_flags = None
         self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
         self._cell_cache_key_extend = cell_cache_key_extend
@@ -634,16 +637,8 @@ class _JitExecutor:
             except Exception as err:
                 _pynative_executor.clear_res()
                 raise err
-        else:  # get compiled args to generate run args by _generate_run_args
-            compile_args = self._generate_compile_args(args_list)
-            key_id = self._get_key_id()
-            if self.input_signature is None:
-                compile_args = get_auto_dynamic_shape_args(
-                    compile_args, key_id, self._enable_auto_dynamic
-                )
-            self._compile_args = compile_args

-        new_inputs = self._generate_run_args(args_list, kwargs)
+        new_inputs = self._generate_run_args(args_list, kwargs, is_predict=True)
         if self.jit_config_dict:
             jit_config_dict = self.jit_config_dict
         else:
@@ -656,11 +651,25 @@ class _JitExecutor:
             res = _convert_python_data(output)
         return True, res

+    def compile_frontend(self, *args, **kwargs):
+        """Only compile to the frontend graph."""
+        args_list = args
+        if self.obj is not None:
+            args_list = args_list[1:]
+        os.environ['MS_DEV_PRECOMPILE_ONLY'] = '1'
+        phase = ""
+        _pynative_executor.set_jit_compile_phase(phase)
+        phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+        _pynative_executor.set_jit_compile_phase(phase)
+        os.unsetenv('MS_DEV_PRECOMPILE_ONLY')
+        return self._graph_executor.get_func_graph(phase), self._mutable_flags, phase, self.enable_tuple_broaden
+
     @_wrap_func
     def __call__(self, *args, **kwargs):
         predict, res = self._predict(*args, **kwargs)
         if predict:
             return res
+        _check_func_args(self.fn, *args)
         if jit_context() and jit_context().is_nested():
             return jit_context().run_graph("", None, *())
         args_list = args
@@ -668,9 +677,9 @@ class _JitExecutor:
             args_list = args_list[1:]
         phase = ""
         try:
-            _pynative_executor.
+            _pynative_executor.set_jit_compile_phase(phase)
             phase = self.compile(self.fn.__name__, *args_list, **kwargs)
-            _pynative_executor.
+            _pynative_executor.set_jit_compile_phase(phase)
         except Exception as err:
             _pynative_executor.clear_res()
             raise err
@@ -694,6 +703,7 @@ class _JitExecutor:
     def compile(self, method_name, *args, **kwargs):
         """Returns pipeline for the given args."""
         # Chose dynamic shape tensors or actual input tensors as compile args.
+        self._graph_executor.set_real_args(args, kwargs)
         compile_args = self._generate_compile_args(args)
         key_id = self._get_key_id()
         if self.input_signature is None:
@@ -705,7 +715,11 @@ class _JitExecutor:
         # 1) Origin args is mutable.
         # 2) Args contains sequence with gradient tensor.
         compile_args = _add_mutable_attr(args, compile_args, _pynative_executor.requires_grad())
-
+        mutable_flags = _get_mutable_flags(compile_args)
+        self._mutable_flags = mutable_flags
+        # Store the _mutable_flags in the cell obj for incremental inference.
+        if self.obj is not None:
+            self.obj._mutable_flags = mutable_flags
         generate_name, echo_function_name = self._get_generate_name()
         # The full Function name
         full_function_name = generate_name
@@ -839,6 +853,7 @@ class _JitExecutor:
         else:
             _pynative_executor.set_dynamic_input(self.fn, *compile_args)
             logger.info(f"dynamic shape compile_args: {compile_args}")
+        Validator.check_symbolic_shape(compile_args, args_list)
         return compile_args

     def _generate_compile_args_by_set_inputs(self, args_list):
@@ -895,7 +910,7 @@ class _JitExecutor:
             # Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
             return _pynative_executor.get_dynamic_input(args_list)

-    def _generate_run_args(self, args_list, kwargs):
+    def _generate_run_args(self, args_list, kwargs, is_predict=False):
         """
         Generate input args, which are required for running.

@@ -906,7 +921,11 @@ class _JitExecutor:
         Returns:
             new_inputs, new input args, which are required for running.
         """
-
+        if self.obj is not None and hasattr(self.obj, '_mutable_flags'):
+            mutable_flags = self.obj._mutable_flags
+        else:
+            mutable_flags = self._mutable_flags
+        return _get_args_for_run(self, args_list, kwargs, mutable_flags, is_predict)

     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
         """Get graph proto from pipeline."""
@@ -978,7 +997,7 @@ def _check_option_backend(option, backend):
         'ge_options': ['GE'],
         'infer_boost': ['ms_backend'],
     }
-    if option in option_backend_cfgs and backend not in option_backend_cfgs[option]:
+    if option in option_backend_cfgs and backend != '' and backend not in option_backend_cfgs[option]:
         logger.warning(f"For 'jit(options)', the option '{option}' is only support backend in "
                        f"'{option_backend_cfgs[option]}', but got '{backend}', ignore it.")

@@ -1187,8 +1206,10 @@ def jit(
            - ms_backend: Utilizes the built-in backend engine of MindSpore for hardware-related compilation
              optimization and execution, supporting multiple hardware forms such as Ascend, GPU, and CPU.
            - GE: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
-              for Ascend model compilation and execution. Note: This backend
-
+              for Ascend model compilation and execution. Note: This backend only supports GRAPH Mode in Ascend,
+              only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
+              dynamic shape scenes. In addition, this backend incurs additional compilation costs and is difficult to
+              debug and tune.

         **options (dict): A dictionary of options to pass to the compilation backend.

@@ -1333,9 +1354,8 @@ def jit(
     jit_level = Validator.check_string(jit_level, ["O0", "O1"], "jit_level", "jit")
     dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
     fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
-    if backend
-        backend =
-    backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
+    if backend != "":
+        backend = Validator.check_string(backend, ["ms_backend", "GE"], "backend", "jit")
     jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
     hash_obj = _get_hash_obj(options)
     _check_options(options, backend)
@@ -1350,7 +1370,7 @@ def jit(
     elif capture_mode == "bytecode":
         wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
     else:
-        wrap_func = _jit_trace()
+        wrap_func = _jit_trace(jit_config)

     if function is not None:
         return wrap_func(function)
@@ -1557,6 +1577,20 @@ def _parameter_broadcast(obj):
         _build_broadcast_graph(broadcast_params_dict, broadcast_phase)


+def _run_in_jit():
+    """In jit, this function always returns true. Otherwise, returns false."""
+    def _temp_func():
+        return 0
+
+    from mindspore.ops.primitive import constexpr
+
+    @constexpr(check=False)
+    def _check_func(func):
+        return func is None
+
+    return _check_func(_temp_func)
+
+
 class _no_grad(contextlib.ContextDecorator):
     """
     Context Manager to disable gradient calculation. When enter this context, we will disable calculate
@@ -1826,17 +1860,16 @@ class _PyNativeExecutor:
         """
         return self._executor.requires_grad()

-    def
+    def set_jit_compile_phase(self, phase):
         """
-        Set jit
+        Set jit phase

         Args:
-            status(bool): jit compile status
             phase (str): The phase of cell/function instance.
         Return:
             None.
         """
-        self._executor.
+        self._executor.set_jit_compile_phase(phase)

     def set_is_run_recompute(self, status):
         """
@@ -1934,6 +1967,19 @@ class _PyNativeExecutor:
         """
         return self._executor.set_creation_type(tensor, creation_type)

+    def queue_backward_final_callback(self, callback):
+        """
+        add backward final callback
+
+        Args:
+            callback(Function): callback function.
+
+        Return:
+            None.
+        """
+        return self._executor.queue_backward_final_callback(callback)
+
+

 class _CellGraphExecutor:
     """
@@ -2075,6 +2121,8 @@ class _CellGraphExecutor:
             obj.phase_cache[raw_phase] = phase
         update_auto_dynamic_shape_phase(args, key_id, phase)
         obj.current_phase = phase
+        obj._add_attr("compile_phase", phase)
+        obj.compile_phase = phase
         if phase in obj.compile_cache and self.has_compiled(phase):
             logger.debug("%r graph has existed.", phase)
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
@@ -2124,6 +2172,10 @@ class _CellGraphExecutor:
         new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
         return self._graph_executor.updata_param_node_default_input(phase, new_param)

+    def set_real_args(self, args, kwargs):
+        """Set real arguments to graph executor."""
+        self._graph_executor.set_real_args(args, kwargs)
+
     def _get_shard_strategy(self, obj):
         real_phase = _real_phase(obj.phase, obj)
         return self._graph_executor.get_strategy(real_phase)
@@ -2213,6 +2265,19 @@ class _CellGraphExecutor:
             return None
         return self._graph_executor.get_func_graph_proto(exec_id, ir_type, incremental)

+    def _get_onnx_func_graph_proto(self, obj, exec_id, use_prefix=False, input_names=None, output_names=None,
+                                   opset_version=11, export_params=True, keep_initializers_as_inputs=False,
+                                   dynamic_axes=None, extra_save_params=False, save_file_dir=None):
+        """Get graph proto from pipeline."""
+        if use_prefix:
+            exec_id = exec_id + '.' + obj.arguments_key
+        if self._graph_executor.has_compiled(exec_id) is False:
+            return None
+
+        return self._graph_executor.get_onnx_func_graph_proto(exec_id, input_names, output_names, opset_version,
+                                                              export_params, keep_initializers_as_inputs, dynamic_axes,
+                                                              extra_save_params, save_file_dir)
+
     def get_optimize_graph_proto(self, obj):
         """Return optimize graph binary proto."""
         exec_id = _real_phase(obj.phase, obj)
@@ -2295,5 +2360,190 @@ def flops_collection(phase='train'):
     return _cell_graph_executor.flops_collection(phase)


+class _ScriptGraph:
+    """Store the graph compiled by the frontend compiler."""
+    def __init__(self, func_graph, func, origin_cell, mutable_flags, phase, enable_tuple_broaden):
+        self.func_graph = func_graph
+        self.func = func
+        self.origin_cell = origin_cell
+        self.mutable_flags = mutable_flags
+        self.phase = phase
+        self.enable_tuple_broaden = enable_tuple_broaden
+
+    def print(self):
+        """Print the MindIR of the frontend graph."""
+        graph_str = dump_func_graph(self.func_graph)
+        print(graph_str, flush=True)
+
+
+def _frontend_compile_ast(dynamic, jit_config, jit_graph_name=''):
+    """Return the wrapped function for ast mode jit."""
+    def wrap_func(func):
+        if hasattr(func, "construct") and isinstance(func, ms.nn.Cell):
+            # Bound the cell object to get the self arg.
+            return types.MethodType(_frontend_compile_ast(dynamic, jit_config,
+                                                          func._jit_graph_name)(func.construct.__func__), func)
+
+        if isinstance(func, types.MethodType):
+            return types.MethodType(_frontend_compile_ast(dynamic, jit_config)(func.__func__), func.__self__)
+
+        if not isinstance(func, types.FunctionType):
+            logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
+            return func
+
+        hash_obj = int(time.time() * 1e9)
+
+        @wraps(func)
+        def staging_specialize(*args, **kwargs):
+            if os.getenv("MS_JIT") == '0':
+                return func(*args, **kwargs)
+
+            args, kwargs = _handle_func_args(func, *args, **kwargs)
+            process_obj = None
+            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
+                process_obj = args[0]
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                setattr(get_func(func), "amp_strategy", get_curr_amp_strategy())
+
+            jit_graph_name = ''
+            if hasattr(staging_specialize, "__jit_graph_name__"):
+                jit_graph_name = staging_specialize.__jit_graph_name__
+            jit_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
+            func_graph, mutable_flags, phase, enable_tuple_broaden = jit_executor.compile_frontend(*args, **kwargs)
+            return _ScriptGraph(func_graph, func, process_obj, mutable_flags, phase, enable_tuple_broaden)

+        # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
+        # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
+        # original `func`.
+        staging_specialize.__signature__ = inspect.signature(func)
+        setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
+        return staging_specialize
+
+    return wrap_func
+
+
+def _frontend_compile(function: Callable,
+                      *,
+                      dynamic: int = 0,
+                      fullgraph: bool = False):
+    """
+    Create a frontend MindSpore graph from a Python function by the ast capture mode.
+
+    Args:
+        function (Callable, optional): The Python function or Cell instance that will be compiled as a frontend graph.
+            Default: ``None``.
+
+    Keyword Args:
+        dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
+            is as follows:
+
+            - `0`: Do not perform dynamic shape compilation.
+            - `1`: Enable dynamic shape compilation and automatically detect shape changes.
+
+        fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
+            be compatible with all Python syntax in the function as much as possible. If True, we require that the
+            entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
+            not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
+            or ``bytecode``. Default: ``False``.
+
+    Returns:
+        a :class:`_ScriptGraph` object.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import ops
+        >>> from mindspore.common.api import _frontend_compile
+        ...
+        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        ...
+        >>> def tensor_add(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
+        >>> tensor_add_graph.print()
+        ...
+    """
+
+    dynamic = Validator.check_int_range(dynamic, 0, 1, Validator.INC_BOTH, "dynamic", "jit")
+    fullgraph = Validator.check_bool(fullgraph, "fullgraph", "jit")
+    jit_syntax_level = "LAX" if fullgraph is False else "STRICT"
+    jit_config = JitConfig(jit_syntax_level=jit_syntax_level)
+    return _frontend_compile_ast(dynamic, jit_config)(function)
+
+
+class _GraphFragment(_GraphFragment_):
+    """
+    Represents the output by backend graph split.
+    """
+    def __init__(self, frag):
+        if frag is None or not isinstance(frag, _GraphFragment_):
+            raise TypeError(f"Expect input `frag` to be a _GraphFragment_, but got {type(frag)}")
+        _GraphFragment_.__init__(self, frag)
+
+    def __call__(self, *args):
+        return super().__call__(args)
+
+    def __repr__(self):
+        return self.__str__()
+
+    def id(self):
+        return self.id_()
+
+    def is_graph(self):
+        return self.is_graph_()
+
+    def py_key(self):
+        return self.py_key_()
+
+    def args_list(self):
+        return self.args_list_()
+
+
+def _graph_split(script_graph):
+    """
+    Split the script_graph into several fragments according to the nodes with the split op attribute.
+
+    Args:
+        a :class:`_ScriptGraph` object.
+
+    Returns:
+        several :class:`_GraphFragment` object.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore import ops
+        >>> from mindspore.common.api import _frontend_compile, _graph_split
+        ...
+        >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        >>> add = ops.Add().add_prim_attr("split_op", True).add_prim_attr("func_id", "add_func")
+        ...
+        >>> def tensor_add(x, y):
+        ...     z1 = x + y
+        ...     z2 = add(z1, x)
+        ...     return z2
+        ...
+        >>> tensor_add_graph = _frontend_compile(tensor_add)(x, y)
+        >>> frags = _graph_split(tensor_add_graph)
+        >>> print(frags)
+        ...
+    """
+    outputs = JitExecutor_.get_instance().split_graph(script_graph.func_graph)
+    fragments = []
+    for arg in outputs:
+        fragments.append(_GraphFragment(arg))
+    return fragments
+
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
mindspore/common/dump.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2021-
+# Copyright 2021-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,115 +14,14 @@
 # ============================================================================
 """Controlling dump behavior."""
 from __future__ import absolute_import
-from
-
-import mindspore.context as context
-from mindspore._c_expression import security
+from mindspore.tools import set_dump as tools_set_dump
+from mindspore.common._decorator import deprecated


+@deprecated("2.7.1", "mindspore.tools.set_dump", module_prefix="mindspore.")
 def set_dump(target, enabled=True):
     """
-
-
-    `target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
-    Please note that this API takes effect only when the Dump function is enabled, and the `dump_mode`
-    field in the Dump configuration file is set to `"2"` with the `ms_backend` compilation backend
-    (please refer to the backend parameter in
-    `jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_).
-    See the `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
-    The default enabled status for
-    a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.
-
-    Note:
-        1. This API is only available for JIT compilation, requires 'Ascend' as the device_target and
-           `ms_backend` as the compilation backend (please refer to the backend parameter in
-           `jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_),
-           and does not support fused operators.
-        2. This API only supports being called before training starts.
-           If you call this API during training, it may not be effective.
-        3. After using `set_dump(Cell, True)` , operators in forward and backward
-           computation (computation generated by the grad operations) of the
-           cell will be dumped.
-        4. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
-           computation and backward computation use the same set of
-           operators. So you can only see dump data from backward computation.
-           Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
-           the above operators internally when initialized with `sparse=True` and
-           `reduction="mean"` .
-
-    Args:
-        target (Union[Cell, Primitive]): The Cell instance or Primitive instance
-            to which the dump flag is set.
-        enabled (bool, optional): ``True`` means enable dump, ``False`` means disable dump.
-            Default: ``True`` .
-
-    Supported Platforms:
-        ``Ascend``
-
-    Examples:
-        .. note::
-            Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
-            in dump config file to 2 before running this example.
-            See `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
-
-        >>> import numpy as np
-        >>> import mindspore as ms
-        >>> import mindspore.nn as nn
-        >>> from mindspore import Tensor, set_dump, jit
-        >>>
-        >>> ms.set_device(device_target="Ascend")
-        >>>
-        >>> class MyNet(nn.Cell):
-        ...     def __init__(self):
-        ...         super().__init__()
-        ...         self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
-        ...         self.relu1 = nn.ReLU()
-        ...
-        ...     @jit
-        ...     def construct(self, x):
-        ...         x = self.conv1(x)
-        ...         x = self.relu1(x)
-        ...         return x
-        >>>
-        >>> if __name__ == "__main__":
-        ...     net = MyNet()
-        ...     set_dump(net.conv1)
-        ...     input_tensor = Tensor(np.ones([1, 5, 10, 10], dtype=np.float32))
-        ...     output = net(input_tensor)
+    This api will be deprecated and removed in future versions, please use the api
+    :func:`mindspore.tools.set_dump` instead.
     """
-
-        raise ValueError('The set_dump API is not supported, please recompile '
-                         'source without "-s on".')
-
-    import mindspore.nn as nn  # avoid circular import
-    from mindspore.ops import Primitive
-    if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
-        raise ValueError(f"The \"target\" parameter must be an instance of "
-                         f"Cell or Primitive, "
-                         f"but got an instance of {type(target)}.")
-
-    if not isinstance(enabled, bool):
-        raise ValueError("The \"enabled\" parameter must be bool.")
-
-    # Checking for device target and mode.
-    current_target = context.get_context("device_target")
-    if current_target != "Ascend":
-        # We will not return here in case user changed device_target later.
-        warn("Current device_target is {}, which is not supported by set_dump. "
-             "Only Ascend device target is supported currently. "
-             "If you have Ascend device, consider set device_target to Ascend "
-             "before calling set_dump.".format(current_target))
-
-    # The actual set dump logic.
-    if isinstance(target, nn.Cell):
-        target.add_flags(dump=enabled)
-        for cell in target.cells():
-            set_dump(cell, enabled)
-
-        primitives = getattr(target, "_primitives", {})
-        for value in primitives.values():
-            if value and "dump" in value.attrs:
-                set_dump(value, enabled)
-
-    if isinstance(target, Primitive):
-        target.add_prim_attr("dump", "true" if enabled else "false")
+    tools_set_dump(target, enabled)
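For orientation, a migration sketch (not part of the diff itself): the change above keeps mindspore.set_dump as a deprecated wrapper that forwards to mindspore.tools.set_dump, so 2.7.0 call sites keep working while new code imports from the new location. This assumes set_dump remains re-exported at the package top level, as it was in 2.7.0.

import mindspore as ms
from mindspore import nn
from mindspore.tools import set_dump   # new home of the API in 2.7.1

conv = nn.Conv2d(5, 6, 5, pad_mode='valid')

ms.set_dump(conv)             # 2.7.0 style: still works, now routed through the deprecated wrapper above
set_dump(conv, enabled=True)  # 2.7.1 style: call mindspore.tools.set_dump directly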
mindspore/common/dynamic_shape/auto_dynamic_shape.py
CHANGED
@@ -275,9 +275,7 @@ class _AutoIdentifyDynamicShape:
                 continue
             if not isinstance(elem, (list, tuple, Tensor, int, float)):
                 return False
-            if isinstance(elem, Tensor) and \
-                    self._is_invalid_shape(elem.shape) and \
-                    not enable_jit_dynamic:
+            if isinstance(elem, Tensor) and self._is_invalid_shape(elem.shape) and not enable_jit_dynamic:
                 return False
             if not is_sink_mode and isinstance(elem, (list, tuple)):
                 return self._is_enable_auto_dynamic_shape(elem, is_sink_mode, enable_jit_dynamic)