mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.0rc1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
- mindspore/_extends/parse/parser.py +22 -28
- mindspore/_extends/parse/standard_method.py +1 -15
- mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/amp.py +18 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +12 -18
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +38 -102
- mindspore/common/_utils.py +1 -9
- mindspore/common/api.py +106 -155
- mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
- mindspore/common/dtype.py +57 -98
- mindspore/common/dump.py +1 -1
- mindspore/common/file_system.py +9 -59
- mindspore/common/hook_handle.py +3 -22
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +20 -4
- mindspore/common/recompute.py +4 -2
- mindspore/common/tensor.py +52 -38
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/context.py +21 -15
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +1 -35
- mindspore/dataset/engine/datasets.py +315 -330
- mindspore/dataset/engine/datasets_user_defined.py +22 -38
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +5 -17
- mindspore/dataset/vision/utils.py +21 -632
- mindspore/device_context/ascend/op_tuning.py +1 -35
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
- mindspore/include/api/cell.h +4 -28
- mindspore/include/api/cfg.h +7 -24
- mindspore/include/api/context.h +0 -1
- mindspore/include/api/delegate.h +2 -0
- mindspore/include/api/dual_abi_helper.h +19 -100
- mindspore/include/api/graph.h +1 -14
- mindspore/include/api/kernel.h +3 -16
- mindspore/include/api/kernel_api.h +1 -9
- mindspore/include/api/metrics/accuracy.h +0 -9
- mindspore/include/api/model.h +1 -5
- mindspore/include/api/model_group.h +0 -4
- mindspore/include/api/model_parallel_runner.h +0 -2
- mindspore/include/api/status.h +10 -48
- mindspore/include/api/types.h +1 -6
- mindspore/include/dataset/constants.h +0 -9
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +2 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/distributed/__init__.py +0 -4
- mindspore/mint/distributed/distributed.py +14 -217
- mindspore/mint/nn/layer/_functions.py +2 -1
- mindspore/mint/nn/layer/conv.py +6 -6
- mindspore/mint/nn/layer/normalization.py +3 -3
- mindspore/nn/cell.py +174 -216
- mindspore/nn/layer/activation.py +2 -4
- mindspore/nn/layer/basic.py +13 -7
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/optim/adam.py +3 -1
- mindspore/nn/optim/lamb.py +3 -1
- mindspore/nn/optim/tft_wrapper.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +5 -39
- mindspore/nn/wrap/grad_reducer.py +15 -0
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_op_impl/cpu/__init__.py +0 -1
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
- mindspore/ops/auto_generate/gen_extend_func.py +4 -4
- mindspore/ops/auto_generate/gen_ops_def.py +16 -290
- mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/function/__init__.py +0 -1
- mindspore/ops/function/array_func.py +6 -10
- mindspore/ops/function/debug_func.py +2 -4
- mindspore/ops/function/grad/grad_func.py +12 -4
- mindspore/ops/function/math_func.py +32 -44
- mindspore/ops/function/nn_func.py +20 -18
- mindspore/ops/functional.py +1 -2
- mindspore/ops/functional_overload.py +12 -23
- mindspore/ops/operations/_inner_ops.py +12 -11
- mindspore/ops/operations/array_ops.py +50 -4
- mindspore/ops/operations/comm_ops.py +15 -1
- mindspore/ops/operations/custom_ops.py +4 -10
- mindspore/ops/operations/debug_ops.py +6 -6
- mindspore/ops/operations/manually_defined/ops_def.py +12 -12
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +1 -1
- mindspore/ops/primitive.py +10 -3
- mindspore/ops/tensor_method.py +7 -16
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
- mindspore/parallel/_auto_parallel_context.py +15 -5
- mindspore/parallel/_parallel_serialization.py +2 -3
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_transformer/transformer.py +4 -4
- mindspore/parallel/_utils.py +11 -5
- mindspore/parallel/auto_parallel.py +9 -23
- mindspore/parallel/checkpoint_transform.py +0 -2
- mindspore/parallel/cluster/process_entity/_api.py +1 -4
- mindspore/parallel/cluster/run.py +3 -5
- mindspore/parallel/function/reshard_func.py +5 -6
- mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
- mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
- mindspore/parallel/shard.py +21 -7
- mindspore/parallel/transform_safetensors.py +4 -10
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +0 -9
- mindspore/profiler/common/profiler_context.py +2 -25
- mindspore/profiler/common/profiler_meta_data.py +0 -1
- mindspore/profiler/common/profiler_op_analyse.py +6 -10
- mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +46 -91
- mindspore/profiler/envprofiler.py +5 -30
- mindspore/profiler/experimental_config.py +1 -16
- mindspore/profiler/platform/cpu_profiler.py +4 -10
- mindspore/profiler/platform/npu_profiler.py +1 -1
- mindspore/profiler/profiler.py +145 -193
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/runtime/__init__.py +4 -6
- mindspore/runtime/executor.py +0 -27
- mindspore/runtime/memory.py +0 -1
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +3 -3
- mindspore/train/amp.py +3 -0
- mindspore/train/callback/_callback.py +1 -2
- mindspore/train/callback/_checkpoint.py +8 -1
- mindspore/train/callback/_flops_collector.py +6 -10
- mindspore/train/callback/_train_fault_tolerance.py +7 -3
- mindspore/train/data_sink.py +4 -4
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +20 -4
- mindspore/train/serialization.py +15 -35
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
- /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
|
@@ -50,14 +50,12 @@ from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcas
|
|
|
50
50
|
_is_parallel_mode
|
|
51
51
|
from mindspore import _checkparam as Validator
|
|
52
52
|
from mindspore._checkparam import is_stub_tensor
|
|
53
|
-
from mindspore.common._utils import is_shape_unknown
|
|
53
|
+
from mindspore.common._utils import is_shape_unknown
|
|
54
54
|
from mindspore.common.mutable import mutable, _check_element_type
|
|
55
|
-
from mindspore.common.
|
|
56
|
-
|
|
57
|
-
from mindspore.common.dynamic_shape.enable_dynamic import generate_dynamic_tensor_args, ENABLE_DYNAMIC
|
|
55
|
+
from mindspore.common.auto_dynamic_shape import get_auto_dynamic_shape_args, update_auto_dynamic_shape_phase, \
|
|
56
|
+
get_auto_dynamic_shape_args_with_check_input_signature, update_auto_dynamic_shape_phase_with_check_input_signature
|
|
58
57
|
from mindspore.common._pijit_context import PIJitCaptureContext
|
|
59
|
-
from mindspore.common.parameter import Parameter
|
|
60
|
-
from mindspore.common.hook_handle import _hook_version
|
|
58
|
+
from mindspore.common.parameter import Parameter, set_parameter_hook_updated, parameter_hook_updated
|
|
61
59
|
from mindspore.common.jit_context import jit_context
|
|
62
60
|
from mindspore.common.jit_trace import _jit_trace
|
|
63
61
|
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
@@ -76,11 +74,6 @@ ARG_SPECIFIED = "arg_specified_infos"
|
|
|
76
74
|
TOTAL_ARG_LEN = "total_arg_length"
|
|
77
75
|
|
|
78
76
|
|
|
79
|
-
def _real_phase(phase, obj):
|
|
80
|
-
real_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
81
|
-
return real_phase
|
|
82
|
-
|
|
83
|
-
|
|
84
77
|
def _check_recompile_args(compile_args, kwargs):
|
|
85
78
|
"""Check recompile of graph"""
|
|
86
79
|
|
|
@@ -545,12 +538,10 @@ def _get_parameter_ids(args, kwargs):
|
|
|
545
538
|
parameter_ids += str(id(value))
|
|
546
539
|
return parameter_ids
|
|
547
540
|
|
|
548
|
-
|
|
549
541
|
def _get_tensor_hook_key(tensor):
|
|
550
542
|
"""Get the hook key of Tensor/Parameter"""
|
|
551
543
|
return ".".join(map(str, map(id, tensor.hooks())))
|
|
552
544
|
|
|
553
|
-
|
|
554
545
|
def _get_hook_key(*args, **kwargs):
|
|
555
546
|
"""Get the hook key of Tensors/Parameters"""
|
|
556
547
|
hook_key = ""
|
|
@@ -597,8 +588,6 @@ class _JitExecutor:
|
|
|
597
588
|
|
|
598
589
|
self.fn = fn
|
|
599
590
|
self.input_signature = input_signature
|
|
600
|
-
self.dynamic_args_shapes = getattr(get_func(fn), ENABLE_DYNAMIC, None)
|
|
601
|
-
self.enable_jit_dynamic = self.dynamic_args_shapes is not None
|
|
602
591
|
self.obj = None
|
|
603
592
|
if obj and hasattr(obj, fn.__name__):
|
|
604
593
|
self.obj = obj
|
|
@@ -637,10 +626,12 @@ class _JitExecutor:
|
|
|
637
626
|
else: # get compiled args to generate run args by _generate_run_args
|
|
638
627
|
compile_args = self._generate_compile_args(args_list)
|
|
639
628
|
key_id = self._get_key_id()
|
|
640
|
-
|
|
641
|
-
compile_args
|
|
642
|
-
|
|
643
|
-
|
|
629
|
+
compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
|
|
630
|
+
compile_args,
|
|
631
|
+
key_id,
|
|
632
|
+
self.input_signature,
|
|
633
|
+
self._enable_auto_dynamic
|
|
634
|
+
)
|
|
644
635
|
self._compile_args = compile_args
|
|
645
636
|
|
|
646
637
|
new_inputs = self._generate_run_args(args_list, kwargs)
|
|
@@ -693,13 +684,18 @@ class _JitExecutor:
|
|
|
693
684
|
|
|
694
685
|
def compile(self, method_name, *args, **kwargs):
|
|
695
686
|
"""Returns pipeline for the given args."""
|
|
687
|
+
# Check whether hook function registered on Cell object.
|
|
688
|
+
if self.obj and hasattr(self.obj, "_hook_fn_registered"):
|
|
689
|
+
if self.obj._hook_fn_registered():
|
|
690
|
+
logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
|
|
691
|
+
f"If you want to use hook function, please use context.set_context to set "
|
|
692
|
+
f"pynative mode and remove 'jit' decorator.")
|
|
696
693
|
# Chose dynamic shape tensors or actual input tensors as compile args.
|
|
697
694
|
compile_args = self._generate_compile_args(args)
|
|
698
695
|
key_id = self._get_key_id()
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
)
|
|
696
|
+
compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
|
|
697
|
+
self.input_signature,
|
|
698
|
+
self._enable_auto_dynamic)
|
|
703
699
|
|
|
704
700
|
# Add mutable for compile_args for two scene:
|
|
705
701
|
# 1) Origin args is mutable.
|
|
@@ -739,23 +735,20 @@ class _JitExecutor:
|
|
|
739
735
|
|
|
740
736
|
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
|
|
741
737
|
key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
|
|
742
|
-
key = str(key)
|
|
743
738
|
|
|
744
739
|
parameter_ids = _get_parameter_ids(args, kwargs)
|
|
745
740
|
if parameter_ids != "":
|
|
746
|
-
key
|
|
741
|
+
key = str(key) + '.' + parameter_ids
|
|
747
742
|
|
|
748
|
-
key
|
|
749
|
-
key += "." + str(_hook_version())
|
|
743
|
+
key = str(key) + "." + _get_hook_key(*args, **kwargs)
|
|
750
744
|
|
|
751
|
-
phase = generate_name + '.' + key
|
|
745
|
+
phase = generate_name + '.' + str(key)
|
|
752
746
|
|
|
753
|
-
|
|
754
|
-
update_auto_dynamic_shape_phase(compile_args, key_id, phase)
|
|
747
|
+
update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
|
|
755
748
|
|
|
756
749
|
phase = phase + self._cell_cache_key_extend
|
|
757
750
|
|
|
758
|
-
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase):
|
|
751
|
+
if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
|
|
759
752
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
760
753
|
# generated in generate_arguments_key.
|
|
761
754
|
self._graph_executor.clear_compile_arguments_resource()
|
|
@@ -772,9 +765,16 @@ class _JitExecutor:
|
|
|
772
765
|
|
|
773
766
|
if self.obj is None:
|
|
774
767
|
# Set an attribute to fn as an identifier.
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
768
|
+
if isinstance(self.fn, types.MethodType):
|
|
769
|
+
setattr(self.fn.__func__, "__jit_function__", True)
|
|
770
|
+
else:
|
|
771
|
+
setattr(self.fn, "__jit_function__", True)
|
|
772
|
+
is_compile = self._graph_executor.compile(
|
|
773
|
+
self.fn, compile_args, kwargs, phase, jit_config_dict)
|
|
774
|
+
if isinstance(self.fn, types.MethodType):
|
|
775
|
+
delattr(self.fn.__func__, "__jit_function__")
|
|
776
|
+
else:
|
|
777
|
+
delattr(self.fn, "__jit_function__")
|
|
778
778
|
else:
|
|
779
779
|
if isinstance(self.obj, ms.nn.Cell):
|
|
780
780
|
self._graph_executor.set_weights_values(self.obj.parameters_dict())
|
|
@@ -783,6 +783,7 @@ class _JitExecutor:
|
|
|
783
783
|
|
|
784
784
|
if not is_compile:
|
|
785
785
|
raise RuntimeError("Executor compile failed.")
|
|
786
|
+
set_parameter_hook_updated(False)
|
|
786
787
|
ms_compile_cache.add(phase)
|
|
787
788
|
if hasattr(self.obj, "phase"):
|
|
788
789
|
self.obj.phase_cache[self.obj.phase] = phase
|
|
@@ -830,70 +831,41 @@ class _JitExecutor:
|
|
|
830
831
|
if enable_compile_cache is True or enable_compile_cache == "1":
|
|
831
832
|
self._graph_executor.set_compile_cache_dep_files(_get_compile_cache_dep_files())
|
|
832
833
|
|
|
833
|
-
def _generate_compile_args_by_enable_dynamic(self, args_list):
|
|
834
|
-
"""Generate compile args by enable_dynamic."""
|
|
835
|
-
compile_args = generate_dynamic_tensor_args(args_list, self.dynamic_args_shapes)
|
|
836
|
-
compile_args = _add_mutable_attr(args_list, compile_args, _pynative_executor.requires_grad())
|
|
837
|
-
if self.obj is not None:
|
|
838
|
-
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
839
|
-
else:
|
|
840
|
-
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
841
|
-
logger.info(f"dynamic shape compile_args: {compile_args}")
|
|
842
|
-
return compile_args
|
|
843
|
-
|
|
844
|
-
def _generate_compile_args_by_set_inputs(self, args_list):
|
|
845
|
-
"""Generate compile args by set_inputs."""
|
|
846
|
-
compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
|
|
847
|
-
if len(compile_args) != len(args_list):
|
|
848
|
-
raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
|
|
849
|
-
f"dynamic shape tensors: {len(compile_args)}.")
|
|
850
|
-
self._graph_executor.check_argument_consistency(compile_args, args_list, "set_inputs")
|
|
851
|
-
Validator.check_symbolic_shape(compile_args, args_list)
|
|
852
|
-
return compile_args
|
|
853
|
-
|
|
854
|
-
def _generate_compile_args_by_input_signature(self, args_list):
|
|
855
|
-
"""Generate compile args by input_signature."""
|
|
856
|
-
compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
|
|
857
|
-
dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
|
|
858
|
-
Validator.check_symbolic_shape(self.input_signature, args_list)
|
|
859
|
-
if dyn_shape:
|
|
860
|
-
# Checkout whether the `sens` has been added to args_list.
|
|
861
|
-
if len(compile_args) == len(args_list) - 1:
|
|
862
|
-
logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
|
|
863
|
-
f"of input_signature args '{len(compile_args)}'. The last actual args may "
|
|
864
|
-
f"be 'sens' and added it to compile args.")
|
|
865
|
-
compile_args.append(args_list[-1])
|
|
866
|
-
compile_args = tuple(compile_args)
|
|
867
|
-
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
868
|
-
if self.obj is not None:
|
|
869
|
-
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
870
|
-
else:
|
|
871
|
-
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
872
|
-
else:
|
|
873
|
-
if not verify_inputs_signature(compile_args, args_list):
|
|
874
|
-
raise ValueError("The input args is incompatible with the args in `input_signature`!")
|
|
875
|
-
return compile_args
|
|
876
|
-
|
|
877
|
-
def _check_set_inputs(self):
|
|
878
|
-
"""Check if the `set_inputs()` of Cell object has been set."""
|
|
879
|
-
return self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs()
|
|
880
|
-
|
|
881
834
|
def _generate_compile_args(self, args_list):
|
|
882
835
|
"""Chose dynamic shape tensors or actual input tensors as compile args."""
|
|
883
|
-
# Case:
|
|
884
|
-
|
|
885
|
-
raise ValueError("When `enable_dynamic` is provided, the `set_inputs()` cannot be set!")
|
|
886
|
-
# Case: The `enable_dynamic` is provided.
|
|
887
|
-
if self.enable_jit_dynamic:
|
|
888
|
-
return self._generate_compile_args_by_enable_dynamic(args_list)
|
|
836
|
+
# Case: If the shape of input args is dynamic, get dynamic shape tensor from context and use it to compile.
|
|
837
|
+
compile_args = _pynative_executor.get_dynamic_input(args_list)
|
|
889
838
|
# Case: The `set_inputs()` of Cell object has been set, using these dynamic shape args as compile args.
|
|
890
|
-
if self.
|
|
891
|
-
|
|
839
|
+
if self.fn.__name__ == 'construct' and isinstance(self.obj, ms.nn.Cell) and self.obj.get_inputs():
|
|
840
|
+
compile_args = _generate_dyn_compile_args(args_list, self.obj.get_inputs())
|
|
841
|
+
if len(compile_args) != len(args_list):
|
|
842
|
+
raise ValueError(f"The number of actual input tensors: {len(args_list)} is not equal to the number of "
|
|
843
|
+
f"dynamic shape tensors: {len(compile_args)}.")
|
|
844
|
+
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
845
|
+
Validator.check_symbolic_shape(compile_args, args_list)
|
|
846
|
+
|
|
892
847
|
# Case: If dynamic shape tensors have been assigned to `input_signature`, they are preferred as compile args.
|
|
893
848
|
if self.input_signature is not None:
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
849
|
+
compile_args = list(_generate_dyn_compile_args(args_list, self.input_signature))
|
|
850
|
+
dyn_shape = any([is_shape_unknown(elem.shape) for elem in compile_args if isinstance(elem, PythonTensor)])
|
|
851
|
+
Validator.check_symbolic_shape(self.input_signature, args_list)
|
|
852
|
+
if dyn_shape:
|
|
853
|
+
# Checkout whether the `sens` has been added to args_list.
|
|
854
|
+
if len(compile_args) == len(args_list) - 1:
|
|
855
|
+
logger.warning(f"The number of actual input args '{len(args_list)}' is one more than the number "
|
|
856
|
+
f"of input_signature args '{len(compile_args)}'. The last actual args may "
|
|
857
|
+
f"be 'sens' and added it to compile args.")
|
|
858
|
+
compile_args.append(args_list[-1])
|
|
859
|
+
compile_args = tuple(compile_args)
|
|
860
|
+
self._graph_executor.check_argument_consistency(compile_args, args_list, "input_signature")
|
|
861
|
+
if self.obj is not None:
|
|
862
|
+
_pynative_executor.set_dynamic_input(self.obj, *compile_args)
|
|
863
|
+
else:
|
|
864
|
+
_pynative_executor.set_dynamic_input(self.fn, *compile_args)
|
|
865
|
+
else:
|
|
866
|
+
if not verify_inputs_signature(compile_args, args_list):
|
|
867
|
+
raise ValueError("The input args is incompatible with the args in `input_signature`!")
|
|
868
|
+
return compile_args
|
|
897
869
|
|
|
898
870
|
def _generate_run_args(self, args_list, kwargs):
|
|
899
871
|
"""
|
|
@@ -1105,7 +1077,10 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
|
|
|
1105
1077
|
process_obj = args[0]
|
|
1106
1078
|
# Handle auto mixed precision strategy.
|
|
1107
1079
|
if not hasattr(func, "amp_strategy"):
|
|
1108
|
-
|
|
1080
|
+
if isinstance(func, types.MethodType):
|
|
1081
|
+
setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
|
|
1082
|
+
else:
|
|
1083
|
+
setattr(func, "amp_strategy", get_curr_amp_strategy())
|
|
1109
1084
|
|
|
1110
1085
|
jit_graph_name = ''
|
|
1111
1086
|
if hasattr(staging_specialize, "__jit_graph_name__"):
|
|
@@ -1113,8 +1088,6 @@ def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
|
|
|
1113
1088
|
jit_executor = _JitExecutor(
|
|
1114
1089
|
func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
|
|
1115
1090
|
out = jit_executor(*args, **kwargs)
|
|
1116
|
-
if isinstance(process_obj, ms.nn.Cell):
|
|
1117
|
-
_clear_auto_parallel_context(process_obj)
|
|
1118
1091
|
return out
|
|
1119
1092
|
|
|
1120
1093
|
# `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
|
|
@@ -1154,26 +1127,28 @@ def jit(
|
|
|
1154
1127
|
|
|
1155
1128
|
Keyword Args:
|
|
1156
1129
|
capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
|
|
1157
|
-
should be ``
|
|
1130
|
+
should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
|
|
1158
1131
|
|
|
1159
|
-
- ast
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1132
|
+
- `ast <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#ast>`_ :
|
|
1133
|
+
Parse Python ast to build graph.
|
|
1134
|
+
- `bytecode <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#bytecode>`_ :
|
|
1135
|
+
Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
|
|
1136
|
+
change and/or deletion.
|
|
1137
|
+
- `trace <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#trace>`_ : Trace the execution of Python code to build graph. This is an experimental prototype that is
|
|
1138
|
+
subject to change and/or deletion.
|
|
1164
1139
|
|
|
1165
1140
|
jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
|
|
1166
|
-
with ms_backend. The value of jit_level should be ``
|
|
1141
|
+
with ms_backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
|
|
1167
1142
|
|
|
1168
|
-
- O0
|
|
1169
|
-
- O1
|
|
1143
|
+
- `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
|
|
1144
|
+
- `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
|
|
1170
1145
|
level is experimental and is being improved.
|
|
1171
1146
|
|
|
1172
1147
|
dynamic (int, optional): Whether dynamic shape compilation should be performed. Default: ``0``. The value range
|
|
1173
1148
|
is as follows:
|
|
1174
1149
|
|
|
1175
|
-
- 0
|
|
1176
|
-
- 1
|
|
1150
|
+
- `0`: Do not perform dynamic shape compilation.
|
|
1151
|
+
- `1`: Enable dynamic shape compilation and automatically detect shape changes.
|
|
1177
1152
|
|
|
1178
1153
|
fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
|
|
1179
1154
|
be compatible with all Python syntax in the function as much as possible. If True, we require that the
|
|
@@ -1181,14 +1156,12 @@ def jit(
|
|
|
1181
1156
|
not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
|
|
1182
1157
|
or ``bytecode``. Default: ``False``.
|
|
1183
1158
|
backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
|
|
1184
|
-
use ``
|
|
1185
|
-
|
|
1159
|
+
use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
|
|
1160
|
+
A2 training series products by default.
|
|
1186
1161
|
|
|
1187
|
-
- ms_backend
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
for Ascend model compilation and execution. Note: This backend takes effect only in static graph mode
|
|
1191
|
-
and can be executed only on Ascend hardware.
|
|
1162
|
+
- `ms_backend`: Adopt KernelByKernel execution mode.
|
|
1163
|
+
- `GE`: Adopt Sink execution mode. The whole model will be sinked to device to execute, only applicable to
|
|
1164
|
+
the top cell of model. And only can be used in Ascend platform.
|
|
1192
1165
|
|
|
1193
1166
|
**options (dict): A dictionary of options to pass to the compilation backend.
|
|
1194
1167
|
|
|
@@ -1211,11 +1184,11 @@ def jit(
|
|
|
1211
1184
|
`disable_format_transform` can be set to ``True`` to try to improve training performance.
|
|
1212
1185
|
Default: ``False`` .
|
|
1213
1186
|
- exec_order (str, optional): Set the sorting method for operator execution, currently only two sorting
|
|
1214
|
-
methods are supported: ``
|
|
1187
|
+
methods are supported: ``bfs`` and ``dfs`` . Default: ``bfs`` .
|
|
1215
1188
|
|
|
1216
|
-
- bfs
|
|
1189
|
+
- `bfs`: The default sorting method, breadth priority, good communication masking, relatively good
|
|
1217
1190
|
performance.
|
|
1218
|
-
- dfs
|
|
1191
|
+
- `dfs`: An optional sorting method, depth-first sorting. The performance is relatively worse than that
|
|
1219
1192
|
of bfs execution order, but it occupies less memory. It is recommended to try dfs in scenarios where
|
|
1220
1193
|
other execution orders run out of memory (OOM).
|
|
1221
1194
|
|
|
@@ -1226,11 +1199,11 @@ def jit(
|
|
|
1226
1199
|
- global (dict): Set global options.
|
|
1227
1200
|
- session (dict): Set session options.
|
|
1228
1201
|
|
|
1229
|
-
- infer_boost (str, optional): Used to control the inference mode. Default: ``
|
|
1202
|
+
- infer_boost (str, optional): Used to control the inference mode. Default: ``off``, which means
|
|
1230
1203
|
the inference mode is disabled. The range is as follows:
|
|
1231
1204
|
|
|
1232
|
-
- on
|
|
1233
|
-
- off
|
|
1205
|
+
- `on`: Enable inference mode, get better infer performance.
|
|
1206
|
+
- `off`: Disable inference mode, use forward for inference. The performance is poor.
|
|
1234
1207
|
|
|
1235
1208
|
Returns:
|
|
1236
1209
|
Function, if `fn` is not None, returns a callable function that will execute the compiled function; If `fn` is
|
|
@@ -1921,19 +1894,6 @@ class _PyNativeExecutor:
|
|
|
1921
1894
|
"""
|
|
1922
1895
|
return self._executor.constant_folding(*args)
|
|
1923
1896
|
|
|
1924
|
-
def set_creation_type(self, tensor, creation_type):
|
|
1925
|
-
"""
|
|
1926
|
-
Set tensor's view creation type
|
|
1927
|
-
|
|
1928
|
-
Args:
|
|
1929
|
-
tensor (Tensor): input tensor.
|
|
1930
|
-
creation_type (CreationType): The type of view tensor when it is created.
|
|
1931
|
-
|
|
1932
|
-
Return:
|
|
1933
|
-
None.
|
|
1934
|
-
"""
|
|
1935
|
-
return self._executor.set_creation_type(tensor, creation_type)
|
|
1936
|
-
|
|
1937
1897
|
|
|
1938
1898
|
class _CellGraphExecutor:
|
|
1939
1899
|
"""
|
|
@@ -2042,11 +2002,6 @@ class _CellGraphExecutor:
|
|
|
2042
2002
|
if not hasattr(obj, obj.__parse_method__):
|
|
2043
2003
|
raise AttributeError(
|
|
2044
2004
|
'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
|
|
2045
|
-
inner_func = inspect.unwrap(obj.construct)
|
|
2046
|
-
if hasattr(get_func(inner_func), ENABLE_DYNAMIC):
|
|
2047
|
-
raise ValueError(
|
|
2048
|
-
"When using set_context(mode=GRAPH_MODE) together with nn.Cell, the 'enable_dynamic' cannot be set!"
|
|
2049
|
-
)
|
|
2050
2005
|
key_id = str(id(obj)) + str(obj.create_time)
|
|
2051
2006
|
args = get_auto_dynamic_shape_args(args, key_id)
|
|
2052
2007
|
|
|
@@ -2057,25 +2012,20 @@ class _CellGraphExecutor:
|
|
|
2057
2012
|
self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
|
|
2058
2013
|
|
|
2059
2014
|
key = self._graph_executor.generate_arguments_key(obj, args, kwargs, self.enable_tuple_broaden)
|
|
2060
|
-
|
|
2015
|
+
obj.arguments_key = str(key)
|
|
2016
|
+
|
|
2017
|
+
obj.arguments_key = obj.arguments_key + "." + _get_hook_key(*args, **kwargs)
|
|
2061
2018
|
|
|
2062
2019
|
# When exist parameter in the top graph inputs, need check if the parameter object has changed.
|
|
2063
2020
|
parameter_ids = _get_parameter_ids(args, kwargs)
|
|
2064
2021
|
if parameter_ids != "":
|
|
2065
|
-
|
|
2066
|
-
|
|
2067
|
-
key += "." + _get_hook_key(*args, **kwargs)
|
|
2068
|
-
key += "." + str(_hook_version())
|
|
2069
|
-
|
|
2070
|
-
obj.arguments_key = key
|
|
2071
|
-
|
|
2022
|
+
obj.arguments_key = obj.arguments_key + '.' + parameter_ids
|
|
2072
2023
|
raw_phase = phase
|
|
2073
|
-
|
|
2074
|
-
phase = _real_phase(phase, obj)
|
|
2024
|
+
phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2075
2025
|
obj.phase_cache[raw_phase] = phase
|
|
2076
2026
|
update_auto_dynamic_shape_phase(args, key_id, phase)
|
|
2077
2027
|
obj.current_phase = phase
|
|
2078
|
-
if phase in obj.compile_cache and self.has_compiled(phase):
|
|
2028
|
+
if phase in obj.compile_cache and self.has_compiled(phase) and not parameter_hook_updated():
|
|
2079
2029
|
logger.debug("%r graph has existed.", phase)
|
|
2080
2030
|
# Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
|
|
2081
2031
|
# generated in generate_arguments_key.
|
|
@@ -2101,6 +2051,7 @@ class _CellGraphExecutor:
|
|
|
2101
2051
|
obj.compile_cache.add(phase)
|
|
2102
2052
|
if not result:
|
|
2103
2053
|
raise RuntimeError("Executor compile failed.")
|
|
2054
|
+
set_parameter_hook_updated(False)
|
|
2104
2055
|
graph = self._graph_executor.get_func_graph(phase)
|
|
2105
2056
|
|
|
2106
2057
|
if graph is None:
|
|
@@ -2125,15 +2076,15 @@ class _CellGraphExecutor:
|
|
|
2125
2076
|
return self._graph_executor.updata_param_node_default_input(phase, new_param)
|
|
2126
2077
|
|
|
2127
2078
|
def _get_shard_strategy(self, obj):
|
|
2128
|
-
real_phase =
|
|
2079
|
+
real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2129
2080
|
return self._graph_executor.get_strategy(real_phase)
|
|
2130
2081
|
|
|
2131
2082
|
def _get_num_parallel_ops(self, obj):
|
|
2132
|
-
real_phase =
|
|
2083
|
+
real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2133
2084
|
return self._graph_executor.get_num_parallel_ops(real_phase)
|
|
2134
2085
|
|
|
2135
2086
|
def _get_allreduce_fusion(self, obj):
|
|
2136
|
-
real_phase =
|
|
2087
|
+
real_phase = obj.phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2137
2088
|
return self._graph_executor.get_allreduce_fusion(real_phase)
|
|
2138
2089
|
|
|
2139
2090
|
def __call__(self, obj, *args, phase='predict'):
|
|
@@ -2185,10 +2136,10 @@ class _CellGraphExecutor:
|
|
|
2185
2136
|
Tensor/Tuple, return execute result.
|
|
2186
2137
|
"""
|
|
2187
2138
|
if phase == 'save':
|
|
2188
|
-
exe_phase =
|
|
2139
|
+
exe_phase = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2189
2140
|
return self._graph_executor((), exe_phase)
|
|
2190
2141
|
|
|
2191
|
-
phase_real =
|
|
2142
|
+
phase_real = phase + '.' + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2192
2143
|
if self.has_compiled(phase_real):
|
|
2193
2144
|
return self._exec_pip(obj, *args, phase=phase_real)
|
|
2194
2145
|
raise KeyError('{} graph is not exist.'.format(phase_real))
|
|
@@ -2215,7 +2166,7 @@ class _CellGraphExecutor:
|
|
|
2215
2166
|
|
|
2216
2167
|
def get_optimize_graph_proto(self, obj):
|
|
2217
2168
|
"""Return optimize graph binary proto."""
|
|
2218
|
-
exec_id =
|
|
2169
|
+
exec_id = obj.phase + "." + str(obj.create_time) + '.' + str(id(obj)) + '.' + obj.arguments_key
|
|
2219
2170
|
if self._graph_executor.has_compiled(exec_id) is False:
|
|
2220
2171
|
return None
|
|
2221
2172
|
graph_proto = self._graph_executor.get_optimize_graph_proto(exec_id)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
2
|
#
|
|
3
|
-
# Copyright 2020-
|
|
3
|
+
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
6
|
# you may not use this file except in compliance with the License.
|
|
@@ -261,12 +261,7 @@ class _AutoIdentifyDynamicShape:
|
|
|
261
261
|
return False
|
|
262
262
|
return True
|
|
263
263
|
|
|
264
|
-
|
|
265
|
-
def _is_invalid_shape(shape):
|
|
266
|
-
"""Check if input shape is valid"""
|
|
267
|
-
return is_shape_unknown(shape) or not shape
|
|
268
|
-
|
|
269
|
-
def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode, enable_jit_dynamic=False):
|
|
264
|
+
def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode):
|
|
270
265
|
"""is enable auto identify shape"""
|
|
271
266
|
if not is_sink_mode and not args_list:
|
|
272
267
|
return False
|
|
@@ -275,12 +270,10 @@ class _AutoIdentifyDynamicShape:
|
|
|
275
270
|
continue
|
|
276
271
|
if not isinstance(elem, (list, tuple, Tensor, int, float)):
|
|
277
272
|
return False
|
|
278
|
-
if isinstance(elem, Tensor) and
|
|
279
|
-
self._is_invalid_shape(elem.shape) and \
|
|
280
|
-
not enable_jit_dynamic:
|
|
273
|
+
if isinstance(elem, Tensor) and (is_shape_unknown(elem.shape) or (not elem.shape)):
|
|
281
274
|
return False
|
|
282
275
|
if not is_sink_mode and isinstance(elem, (list, tuple)):
|
|
283
|
-
return self._is_enable_auto_dynamic_shape(elem, is_sink_mode
|
|
276
|
+
return self._is_enable_auto_dynamic_shape(elem, is_sink_mode)
|
|
284
277
|
return True
|
|
285
278
|
|
|
286
279
|
@staticmethod
|
|
@@ -335,10 +328,10 @@ class _AutoIdentifyDynamicShape:
|
|
|
335
328
|
logger.info((f'generalize with generalize shape cache, compile args shape = {res_shape}'))
|
|
336
329
|
return new_generalize_shape
|
|
337
330
|
|
|
338
|
-
def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode
|
|
331
|
+
def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode):
|
|
339
332
|
"""generate compile args in auto dynamic shape"""
|
|
340
333
|
if not self.is_enable_auto_dynamic_shape or \
|
|
341
|
-
not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode
|
|
334
|
+
not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode) or \
|
|
342
335
|
not self._check_input_number_and_type(args_list):
|
|
343
336
|
self.is_enable_auto_dynamic_shape = False
|
|
344
337
|
return args_list
|
|
@@ -482,13 +475,11 @@ class _AutoIdentifyDynamicShape:
|
|
|
482
475
|
_auto_dynamic_shape = _AutoIdentifyDynamicShape()
|
|
483
476
|
|
|
484
477
|
|
|
485
|
-
def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False
|
|
478
|
+
def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False):
|
|
486
479
|
"""get auto dynamic shape args."""
|
|
487
480
|
if key_id not in auto_dynamic_shape_dict:
|
|
488
481
|
auto_dynamic_shape_dict[key_id] = _AutoIdentifyDynamicShape(enable_auto_dynamic)
|
|
489
|
-
compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(
|
|
490
|
-
compile_args, False, enable_jit_dynamic
|
|
491
|
-
)
|
|
482
|
+
compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(compile_args, False)
|
|
492
483
|
return compile_args
|
|
493
484
|
|
|
494
485
|
|
|
@@ -496,3 +487,18 @@ def update_auto_dynamic_shape_phase(compile_args, key_id, phase):
|
|
|
496
487
|
"""update auto dynamic shape phase."""
|
|
497
488
|
if key_id in auto_dynamic_shape_dict:
|
|
498
489
|
auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id, input_signature,
|
|
493
|
+
enable_auto_dynamic=False):
|
|
494
|
+
"""get auto dynamic shape args."""
|
|
495
|
+
if input_signature is None:
|
|
496
|
+
return get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic)
|
|
497
|
+
return compile_args
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, input_signature):
|
|
501
|
+
"""update auto dynamic shape phase."""
|
|
502
|
+
if input_signature is None:
|
|
503
|
+
if key_id in auto_dynamic_shape_dict:
|
|
504
|
+
auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
|