mindspore 2.7.0rc1__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/dtype.py
CHANGED
|
@@ -17,21 +17,25 @@
|
|
|
17
17
|
"""Data type for MindSpore."""
|
|
18
18
|
from __future__ import absolute_import
|
|
19
19
|
|
|
20
|
+
import builtins
|
|
20
21
|
import enum
|
|
21
22
|
from inspect import isfunction
|
|
22
23
|
import numpy as np
|
|
24
|
+
from mindspore import log as logger
|
|
23
25
|
from mindspore._c_expression import typing
|
|
24
26
|
from mindspore._c_expression.typing import Type
|
|
25
|
-
from mindspore._c_expression.np_dtypes import
|
|
27
|
+
from mindspore._c_expression.np_dtypes import np_dtype_valid
|
|
26
28
|
|
|
27
|
-
if
|
|
29
|
+
if np_dtype_valid(False):
|
|
28
30
|
from mindspore._c_expression.np_dtypes import bfloat16 as np_bfloat16
|
|
29
31
|
|
|
32
|
+
# bool, int, float are not defined in __all__ to avoid conflict with built-in types.
|
|
30
33
|
__dtype__ = [
|
|
34
|
+
"bool_",
|
|
31
35
|
"int8", "byte",
|
|
32
36
|
"int16", "short",
|
|
33
37
|
"int32", "intc",
|
|
34
|
-
"int64", "intp",
|
|
38
|
+
"int64", "long", "intp",
|
|
35
39
|
"uint8", "ubyte",
|
|
36
40
|
"uint16", "ushort",
|
|
37
41
|
"uint32", "uintc",
|
|
@@ -39,17 +43,15 @@ __dtype__ = [
|
|
|
39
43
|
"float16", "half",
|
|
40
44
|
"float32", "single",
|
|
41
45
|
"float64", "double",
|
|
42
|
-
"
|
|
43
|
-
"
|
|
44
|
-
"
|
|
46
|
+
"complex64", "cfloat",
|
|
47
|
+
"complex128", "cdouble",
|
|
48
|
+
"qint4x2", "bfloat16",
|
|
49
|
+
"float8_e4m3fn", "float8_e5m2", "hifloat8",
|
|
50
|
+
"int_", "uint", "float_",
|
|
51
|
+
"list_", "tuple_", "string",
|
|
45
52
|
"number", "tensor_type",
|
|
46
|
-
"
|
|
47
|
-
"TensorType", "
|
|
48
|
-
"Type", "Int",
|
|
49
|
-
"complex64", "complex128",
|
|
50
|
-
"bfloat16", "qint4x2",
|
|
51
|
-
"float8_e4m3fn", "float8_e5m2",
|
|
52
|
-
"hifloat8"
|
|
53
|
+
"type_none", "_null",
|
|
54
|
+
"TensorType", "Type", "Int",
|
|
53
55
|
]
|
|
54
56
|
|
|
55
57
|
__method__ = [
|
|
@@ -62,16 +64,18 @@ __all__.extend(__dtype__)
|
|
|
62
64
|
__all__.extend(__method__)
|
|
63
65
|
|
|
64
66
|
# type definition
|
|
65
|
-
|
|
67
|
+
bool = typing.kBool
|
|
68
|
+
bool_ = bool
|
|
66
69
|
|
|
67
|
-
qint4x2 = typing.kInt4
|
|
68
70
|
int8 = typing.kInt8
|
|
69
71
|
byte = int8
|
|
70
72
|
int16 = typing.kInt16
|
|
71
73
|
short = int16
|
|
72
74
|
int32 = typing.kInt32
|
|
75
|
+
int = int32
|
|
73
76
|
intc = int32
|
|
74
77
|
int64 = typing.kInt64
|
|
78
|
+
long = int64
|
|
75
79
|
intp = int64
|
|
76
80
|
|
|
77
81
|
uint8 = typing.kUInt8
|
|
@@ -86,15 +90,21 @@ uintp = uint64
|
|
|
86
90
|
float16 = typing.kFloat16
|
|
87
91
|
half = float16
|
|
88
92
|
float32 = typing.kFloat32
|
|
93
|
+
float = float32
|
|
89
94
|
single = float32
|
|
90
95
|
float64 = typing.kFloat64
|
|
91
96
|
double = float64
|
|
97
|
+
|
|
98
|
+
qint4x2 = typing.kInt4
|
|
92
99
|
float8_e4m3fn = typing.kFloat8E4M3FN
|
|
93
100
|
float8_e5m2 = typing.kFloat8E5M2
|
|
94
101
|
hifloat8 = typing.kHiFloat8
|
|
95
102
|
bfloat16 = typing.kBFloat16
|
|
103
|
+
|
|
96
104
|
complex64 = typing.kComplex64
|
|
105
|
+
cfloat = complex64
|
|
97
106
|
complex128 = typing.kComplex128
|
|
107
|
+
cdouble = complex128
|
|
98
108
|
|
|
99
109
|
number = typing.kNumber
|
|
100
110
|
int_ = typing.kInt
|
|
@@ -137,41 +147,25 @@ AnythingType = typing.TypeAny
|
|
|
137
147
|
RefType = typing.RefType
|
|
138
148
|
_NullType = typing.TypeNull
|
|
139
149
|
|
|
140
|
-
number_type = (int8,
|
|
141
|
-
|
|
142
|
-
int32,
|
|
143
|
-
int64,
|
|
144
|
-
uint8,
|
|
145
|
-
uint16,
|
|
146
|
-
uint32,
|
|
147
|
-
uint64,
|
|
148
|
-
float16,
|
|
149
|
-
float32,
|
|
150
|
-
float64,
|
|
151
|
-
bfloat16,
|
|
152
|
-
complex64,
|
|
153
|
-
complex128,
|
|
154
|
-
qint4x2,
|
|
155
|
-
float8_e4m3fn,
|
|
156
|
-
float8_e5m2,
|
|
157
|
-
hifloat8)
|
|
150
|
+
number_type = (int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64, bfloat16, complex64,
|
|
151
|
+
complex128, qint4x2, float8_e4m3fn, float8_e5m2, hifloat8)
|
|
158
152
|
|
|
159
153
|
int_type = (int8, int16, int32, int64,)
|
|
160
154
|
uint_type = (uint8, uint16, uint32, uint64,)
|
|
161
155
|
float_type = (float16, float32, float64, bfloat16, float8_e4m3fn, float8_e5m2, hifloat8)
|
|
162
|
-
signed_type = (int8,
|
|
163
|
-
|
|
156
|
+
signed_type = (int8, int16, int32, int64, float16, float32, float64, bfloat16, complex64, complex128, qint4x2,
|
|
157
|
+
float8_e4m3fn, float8_e5m2, hifloat8)
|
|
164
158
|
complex_type = (complex64, complex128,)
|
|
165
|
-
all_types = (
|
|
166
|
-
float8_e4m3fn, float8_e5m2, hifloat8)
|
|
159
|
+
all_types = (bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float16, float32, float64, bfloat16,
|
|
160
|
+
complex64, complex128, qint4x2, float8_e4m3fn, float8_e5m2, hifloat8)
|
|
167
161
|
|
|
168
162
|
_simple_types = {
|
|
169
163
|
list: list_,
|
|
170
164
|
tuple: tuple_,
|
|
171
165
|
type(None): type_none,
|
|
172
|
-
bool: bool_,
|
|
173
|
-
int: int64,
|
|
174
|
-
float: float64,
|
|
166
|
+
builtins.bool: bool_,
|
|
167
|
+
builtins.int: int64,
|
|
168
|
+
builtins.float: float64,
|
|
175
169
|
complex: complex128,
|
|
176
170
|
str: string,
|
|
177
171
|
np.bool_: bool_,
|
|
@@ -194,6 +188,9 @@ def pytype_to_dtype(obj):
|
|
|
194
188
|
"""
|
|
195
189
|
Convert python type to MindSpore type.
|
|
196
190
|
|
|
191
|
+
Note:
|
|
192
|
+
The interface is deprecated from version 2.7 and will be removed in a future version.
|
|
193
|
+
|
|
197
194
|
Args:
|
|
198
195
|
obj (type): A python type object.
|
|
199
196
|
|
|
@@ -209,6 +206,15 @@ def pytype_to_dtype(obj):
|
|
|
209
206
|
>>> print(out)
|
|
210
207
|
Bool
|
|
211
208
|
"""
|
|
209
|
+
logger.warning("The interface 'mindspore.pytype_to_dtype' is deprecated from version 2.7 "
|
|
210
|
+
"and will be removed in a future version.")
|
|
211
|
+
return _pytype_to_dtype(obj)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def _pytype_to_dtype(obj):
|
|
215
|
+
"""
|
|
216
|
+
Convert python type to MindSpore type.
|
|
217
|
+
"""
|
|
212
218
|
|
|
213
219
|
if isinstance(obj, np.dtype):
|
|
214
220
|
obj = obj.type
|
|
@@ -226,6 +232,9 @@ def get_py_obj_dtype(obj):
|
|
|
226
232
|
"""
|
|
227
233
|
Get the MindSpore data type, which corresponds to python type or variable.
|
|
228
234
|
|
|
235
|
+
Note:
|
|
236
|
+
The interface is deprecated from version 2.7 and will be removed in a future version.
|
|
237
|
+
|
|
229
238
|
Args:
|
|
230
239
|
obj (type): An object of python type, or a variable of python type.
|
|
231
240
|
|
|
@@ -237,6 +246,15 @@ def get_py_obj_dtype(obj):
|
|
|
237
246
|
>>> ms.get_py_obj_dtype(1)
|
|
238
247
|
mindspore.int64
|
|
239
248
|
"""
|
|
249
|
+
logger.warning("The interface 'mindspore.get_py_obj_dtype' is deprecated from version 2.7 "
|
|
250
|
+
"and will be removed in a future version.")
|
|
251
|
+
return _get_py_obj_dtype(obj)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def _get_py_obj_dtype(obj):
|
|
255
|
+
"""
|
|
256
|
+
Get the MindSpore data type, which corresponds to python type or variable.
|
|
257
|
+
"""
|
|
240
258
|
# Tensor
|
|
241
259
|
if hasattr(obj, 'shape') and hasattr(obj, 'dtype') and isinstance(obj.dtype, typing.Type):
|
|
242
260
|
return TensorType(obj.dtype)
|
|
@@ -260,6 +278,9 @@ def dtype_to_nptype(type_):
|
|
|
260
278
|
r"""
|
|
261
279
|
Convert MindSpore dtype to numpy data type.
|
|
262
280
|
|
|
281
|
+
Note:
|
|
282
|
+
The interface is deprecated from version 2.7 and will be removed in a future version.
|
|
283
|
+
|
|
263
284
|
Args:
|
|
264
285
|
type\_ (:class:`mindspore.dtype`): MindSpore's dtype.
|
|
265
286
|
|
|
@@ -271,6 +292,15 @@ def dtype_to_nptype(type_):
|
|
|
271
292
|
>>> ms.dtype_to_nptype(ms.int8)
|
|
272
293
|
<class 'numpy.int8'>
|
|
273
294
|
"""
|
|
295
|
+
logger.warning("The interface 'mindspore.dtype_to_nptype' is deprecated from version 2.7 "
|
|
296
|
+
"and will be removed in a future version.")
|
|
297
|
+
return _dtype_to_nptype(type_)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _dtype_to_nptype(type_):
|
|
301
|
+
"""
|
|
302
|
+
Convert MindSpore dtype to numpy data type.
|
|
303
|
+
"""
|
|
274
304
|
_dtype_nptype_dict = {
|
|
275
305
|
bool_: np.bool_,
|
|
276
306
|
int8: np.int8,
|
|
@@ -288,7 +318,7 @@ def dtype_to_nptype(type_):
|
|
|
288
318
|
complex128: np.complex128,
|
|
289
319
|
}
|
|
290
320
|
if type_ == bfloat16:
|
|
291
|
-
if not
|
|
321
|
+
if not np_dtype_valid(True):
|
|
292
322
|
raise TypeError(
|
|
293
323
|
"The Numpy bfloat16 data type is not supported now, please ensure that the current "
|
|
294
324
|
"Numpy version is not less than the version when the mindspore is compiled, "
|
|
@@ -302,6 +332,9 @@ def dtype_to_pytype(type_):
|
|
|
302
332
|
r"""
|
|
303
333
|
Convert MindSpore dtype to python data type.
|
|
304
334
|
|
|
335
|
+
Note:
|
|
336
|
+
The interface is deprecated from version 2.7 and will be removed in a future version.
|
|
337
|
+
|
|
305
338
|
Args:
|
|
306
339
|
type\_ (:class:`mindspore.dtype`): MindSpore's dtype.
|
|
307
340
|
|
|
@@ -314,23 +347,31 @@ def dtype_to_pytype(type_):
|
|
|
314
347
|
>>> print(out)
|
|
315
348
|
<class 'bool'>
|
|
316
349
|
"""
|
|
350
|
+
logger.warning("The interface 'mindspore.dtype_to_pytype' is deprecated from version 2.7 "
|
|
351
|
+
"and will be removed in a future version.")
|
|
352
|
+
return _dtype_to_pytype(type_)
|
|
353
|
+
|
|
317
354
|
|
|
355
|
+
def _dtype_to_pytype(type_):
|
|
356
|
+
"""
|
|
357
|
+
Convert MindSpore dtype to python data type.
|
|
358
|
+
"""
|
|
318
359
|
return {
|
|
319
|
-
bool_: bool,
|
|
320
|
-
int_: int,
|
|
321
|
-
int8: int,
|
|
322
|
-
int16: int,
|
|
323
|
-
int32: int,
|
|
324
|
-
int64: int,
|
|
325
|
-
uint8: int,
|
|
326
|
-
uint16: int,
|
|
327
|
-
uint32: int,
|
|
328
|
-
uint64: int,
|
|
329
|
-
float_: float,
|
|
330
|
-
float16: float,
|
|
331
|
-
float32: float,
|
|
332
|
-
float64: float,
|
|
333
|
-
bfloat16: float,
|
|
360
|
+
bool_: builtins.bool,
|
|
361
|
+
int_: builtins.int,
|
|
362
|
+
int8: builtins.int,
|
|
363
|
+
int16: builtins.int,
|
|
364
|
+
int32: builtins.int,
|
|
365
|
+
int64: builtins.int,
|
|
366
|
+
uint8: builtins.int,
|
|
367
|
+
uint16: builtins.int,
|
|
368
|
+
uint32: builtins.int,
|
|
369
|
+
uint64: builtins.int,
|
|
370
|
+
float_: builtins.float,
|
|
371
|
+
float16: builtins.float,
|
|
372
|
+
float32: builtins.float,
|
|
373
|
+
float64: builtins.float,
|
|
374
|
+
bfloat16: builtins.float,
|
|
334
375
|
list_: list,
|
|
335
376
|
tuple_: tuple,
|
|
336
377
|
string: str,
|
|
@@ -417,7 +458,7 @@ class QuantDtype(enum.Enum):
|
|
|
417
458
|
def __str__(self):
|
|
418
459
|
return f"{self.name}"
|
|
419
460
|
|
|
420
|
-
def value(self) -> int:
|
|
461
|
+
def value(self) -> builtins.int:
|
|
421
462
|
"""
|
|
422
463
|
Return value of `QuantDtype`. This interface is currently used to serialize or deserialize `QuantDtype`
|
|
423
464
|
primarily.
|
mindspore/common/dump.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2021-
|
|
1
|
+
# Copyright 2021-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -14,115 +14,14 @@
|
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""Controlling dump behavior."""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
|
-
from
|
|
18
|
-
|
|
19
|
-
import mindspore.context as context
|
|
20
|
-
from mindspore._c_expression import security
|
|
17
|
+
from mindspore.tools import set_dump as tools_set_dump
|
|
18
|
+
from mindspore.common._decorator import deprecated
|
|
21
19
|
|
|
22
20
|
|
|
21
|
+
@deprecated("2.7.1", "mindspore.tools.set_dump", module_prefix="mindspore.")
|
|
23
22
|
def set_dump(target, enabled=True):
|
|
24
23
|
"""
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
`target` should be an instance of :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` .
|
|
28
|
-
Please note that this API takes effect only when the Dump function is enabled, and the `dump_mode`
|
|
29
|
-
field in the Dump configuration file is set to `"2"` with the `ms_backend` compilation backend
|
|
30
|
-
(please refer to the backend parameter in
|
|
31
|
-
`jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_).
|
|
32
|
-
See the `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
|
|
33
|
-
The default enabled status for
|
|
34
|
-
a :class:`mindspore.nn.Cell` or :class:`mindspore.ops.Primitive` is False.
|
|
35
|
-
|
|
36
|
-
Note:
|
|
37
|
-
1. This API is only available for JIT compilation, requires 'Ascend' as the device_target and
|
|
38
|
-
`ms_backend` as the compilation backend (please refer to the backend parameter in
|
|
39
|
-
`jit <https://www.mindspore.cn/docs/en/master/api_python/mindspore/mindspore.jit.html>`_),
|
|
40
|
-
and does not support fused operators.
|
|
41
|
-
2. This API only supports being called before training starts.
|
|
42
|
-
If you call this API during training, it may not be effective.
|
|
43
|
-
3. After using `set_dump(Cell, True)` , operators in forward and backward
|
|
44
|
-
computation (computation generated by the grad operations) of the
|
|
45
|
-
cell will be dumped.
|
|
46
|
-
4. For :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer, the forward
|
|
47
|
-
computation and backward computation use the same set of
|
|
48
|
-
operators. So you can only see dump data from backward computation.
|
|
49
|
-
Please note that :class:`mindspore.nn.SoftmaxCrossEntropyWithLogits` layer will also use
|
|
50
|
-
the above operators internally when initialized with `sparse=True` and
|
|
51
|
-
`reduction="mean"` .
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
target (Union[Cell, Primitive]): The Cell instance or Primitive instance
|
|
55
|
-
to which the dump flag is set.
|
|
56
|
-
enabled (bool, optional): ``True`` means enable dump, ``False`` means disable dump.
|
|
57
|
-
Default: ``True`` .
|
|
58
|
-
|
|
59
|
-
Supported Platforms:
|
|
60
|
-
``Ascend``
|
|
61
|
-
|
|
62
|
-
Examples:
|
|
63
|
-
.. note::
|
|
64
|
-
Please set environment variable `MINDSPORE_DUMP_CONFIG` to the dump config file and set `dump_mode` field
|
|
65
|
-
in dump config file to 2 before running this example.
|
|
66
|
-
See `dump document <https://www.mindspore.cn/tutorials/en/master/debug/dump.html>`_ for details.
|
|
67
|
-
|
|
68
|
-
>>> import numpy as np
|
|
69
|
-
>>> import mindspore as ms
|
|
70
|
-
>>> import mindspore.nn as nn
|
|
71
|
-
>>> from mindspore import Tensor, set_dump
|
|
72
|
-
>>>
|
|
73
|
-
>>> ms.set_device(device_target="Ascend")
|
|
74
|
-
>>>
|
|
75
|
-
>>> class MyNet(nn.Cell):
|
|
76
|
-
... def __init__(self):
|
|
77
|
-
... super().__init__()
|
|
78
|
-
... self.conv1 = nn.Conv2d(5, 6, 5, pad_mode='valid')
|
|
79
|
-
... self.relu1 = nn.ReLU()
|
|
80
|
-
...
|
|
81
|
-
... @jit
|
|
82
|
-
... def construct(self, x):
|
|
83
|
-
... x = self.conv1(x)
|
|
84
|
-
... x = self.relu1(x)
|
|
85
|
-
... return x
|
|
86
|
-
>>>
|
|
87
|
-
>>> if __name__ == "__main__":
|
|
88
|
-
... net = MyNet()
|
|
89
|
-
... set_dump(net.conv1)
|
|
90
|
-
... input_tensor = Tensor(np.ones([1, 5, 10, 10], dtype=np.float32))
|
|
91
|
-
... output = net(input_tensor)
|
|
24
|
+
This api will be deprecated and removed in future versions, please use the api
|
|
25
|
+
:func:`mindspore.tools.set_dump` instead.
|
|
92
26
|
"""
|
|
93
|
-
|
|
94
|
-
raise ValueError('The set_dump API is not supported, please recompile '
|
|
95
|
-
'source without "-s on".')
|
|
96
|
-
|
|
97
|
-
import mindspore.nn as nn # avoid circular import
|
|
98
|
-
from mindspore.ops import Primitive
|
|
99
|
-
if not isinstance(target, nn.Cell) and not isinstance(target, Primitive):
|
|
100
|
-
raise ValueError(f"The \"target\" parameter must be an instance of "
|
|
101
|
-
f"Cell or Primitive, "
|
|
102
|
-
f"but got an instance of {type(target)}.")
|
|
103
|
-
|
|
104
|
-
if not isinstance(enabled, bool):
|
|
105
|
-
raise ValueError("The \"enabled\" parameter must be bool.")
|
|
106
|
-
|
|
107
|
-
# Checking for device target and mode.
|
|
108
|
-
current_target = context.get_context("device_target")
|
|
109
|
-
if current_target != "Ascend":
|
|
110
|
-
# We will not return here in case user changed device_target later.
|
|
111
|
-
warn("Current device_target is {}, which is not supported by set_dump. "
|
|
112
|
-
"Only Ascend device target is supported currently. "
|
|
113
|
-
"If you have Ascend device, consider set device_target to Ascend "
|
|
114
|
-
"before calling set_dump.".format(current_target))
|
|
115
|
-
|
|
116
|
-
# The actual set dump logic.
|
|
117
|
-
if isinstance(target, nn.Cell):
|
|
118
|
-
target.add_flags(dump=enabled)
|
|
119
|
-
for cell in target.cells():
|
|
120
|
-
set_dump(cell, enabled)
|
|
121
|
-
|
|
122
|
-
primitives = getattr(target, "_primitives", {})
|
|
123
|
-
for value in primitives.values():
|
|
124
|
-
if value and "dump" in value.attrs:
|
|
125
|
-
set_dump(value, enabled)
|
|
126
|
-
|
|
127
|
-
if isinstance(target, Primitive):
|
|
128
|
-
target.add_prim_attr("dump", "true" if enabled else "false")
|
|
27
|
+
tools_set_dump(target, enabled)
|
|
File without changes
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
2
|
#
|
|
3
|
-
# Copyright 2020-
|
|
3
|
+
# Copyright 2020-2025 Huawei Technologies Co., Ltd
|
|
4
4
|
#
|
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
6
|
# you may not use this file except in compliance with the License.
|
|
@@ -261,7 +261,12 @@ class _AutoIdentifyDynamicShape:
|
|
|
261
261
|
return False
|
|
262
262
|
return True
|
|
263
263
|
|
|
264
|
-
|
|
264
|
+
@staticmethod
|
|
265
|
+
def _is_invalid_shape(shape):
|
|
266
|
+
"""Check if input shape is valid"""
|
|
267
|
+
return is_shape_unknown(shape) or not shape
|
|
268
|
+
|
|
269
|
+
def _is_enable_auto_dynamic_shape(self, args_list, is_sink_mode, enable_jit_dynamic=False):
|
|
265
270
|
"""is enable auto identify shape"""
|
|
266
271
|
if not is_sink_mode and not args_list:
|
|
267
272
|
return False
|
|
@@ -270,10 +275,10 @@ class _AutoIdentifyDynamicShape:
|
|
|
270
275
|
continue
|
|
271
276
|
if not isinstance(elem, (list, tuple, Tensor, int, float)):
|
|
272
277
|
return False
|
|
273
|
-
if isinstance(elem, Tensor) and (
|
|
278
|
+
if isinstance(elem, Tensor) and self._is_invalid_shape(elem.shape) and not enable_jit_dynamic:
|
|
274
279
|
return False
|
|
275
280
|
if not is_sink_mode and isinstance(elem, (list, tuple)):
|
|
276
|
-
return self._is_enable_auto_dynamic_shape(elem, is_sink_mode)
|
|
281
|
+
return self._is_enable_auto_dynamic_shape(elem, is_sink_mode, enable_jit_dynamic)
|
|
277
282
|
return True
|
|
278
283
|
|
|
279
284
|
@staticmethod
|
|
@@ -328,10 +333,10 @@ class _AutoIdentifyDynamicShape:
|
|
|
328
333
|
logger.info((f'generalize with generalize shape cache, compile args shape = {res_shape}'))
|
|
329
334
|
return new_generalize_shape
|
|
330
335
|
|
|
331
|
-
def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode):
|
|
336
|
+
def auto_dynamic_generate_compile_args(self, args_list, is_sink_mode, enable_jit_dynamic=False):
|
|
332
337
|
"""generate compile args in auto dynamic shape"""
|
|
333
338
|
if not self.is_enable_auto_dynamic_shape or \
|
|
334
|
-
not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode) or \
|
|
339
|
+
not self._is_enable_auto_dynamic_shape(args_list, is_sink_mode, enable_jit_dynamic) or \
|
|
335
340
|
not self._check_input_number_and_type(args_list):
|
|
336
341
|
self.is_enable_auto_dynamic_shape = False
|
|
337
342
|
return args_list
|
|
@@ -475,11 +480,13 @@ class _AutoIdentifyDynamicShape:
|
|
|
475
480
|
_auto_dynamic_shape = _AutoIdentifyDynamicShape()
|
|
476
481
|
|
|
477
482
|
|
|
478
|
-
def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False):
|
|
483
|
+
def get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic=False, enable_jit_dynamic=False):
|
|
479
484
|
"""get auto dynamic shape args."""
|
|
480
485
|
if key_id not in auto_dynamic_shape_dict:
|
|
481
486
|
auto_dynamic_shape_dict[key_id] = _AutoIdentifyDynamicShape(enable_auto_dynamic)
|
|
482
|
-
compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(
|
|
487
|
+
compile_args = auto_dynamic_shape_dict[key_id].auto_dynamic_generate_compile_args(
|
|
488
|
+
compile_args, False, enable_jit_dynamic
|
|
489
|
+
)
|
|
483
490
|
return compile_args
|
|
484
491
|
|
|
485
492
|
|
|
@@ -487,18 +494,3 @@ def update_auto_dynamic_shape_phase(compile_args, key_id, phase):
|
|
|
487
494
|
"""update auto dynamic shape phase."""
|
|
488
495
|
if key_id in auto_dynamic_shape_dict:
|
|
489
496
|
auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
def get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id, input_signature,
|
|
493
|
-
enable_auto_dynamic=False):
|
|
494
|
-
"""get auto dynamic shape args."""
|
|
495
|
-
if input_signature is None:
|
|
496
|
-
return get_auto_dynamic_shape_args(compile_args, key_id, enable_auto_dynamic)
|
|
497
|
-
return compile_args
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
def update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, input_signature):
|
|
501
|
-
"""update auto dynamic shape phase."""
|
|
502
|
-
if input_signature is None:
|
|
503
|
-
if key_id in auto_dynamic_shape_dict:
|
|
504
|
-
auto_dynamic_shape_dict[key_id].update_phase_and_compile_args(compile_args, phase, False)
|