mindspore-2.7.0rc1-cp311-cp311-win_amd64.whl → mindspore-2.7.1-cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +5 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -3
- mindspore/_extends/parse/parser.py +28 -22
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +23 -2
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +9 -27
- mindspore/amp.py +0 -18
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/__init__.py +18 -12
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +371 -96
- mindspore/common/_utils.py +7 -43
- mindspore/common/api.py +434 -135
- mindspore/common/dtype.py +98 -57
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +15 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/hook_handle.py +82 -3
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +17 -127
- mindspore/common/recompute.py +4 -13
- mindspore/common/tensor.py +50 -217
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +20 -106
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +35 -1
- mindspore/dataset/engine/datasets.py +338 -319
- mindspore/dataset/engine/datasets_user_defined.py +38 -22
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{profiler/common/validator → graph}/__init__.py +9 -1
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/api/cell.h +28 -4
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +0 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +5 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +6 -1
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +4 -3
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +6 -6
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/__init__.py +4 -0
- mindspore/mint/distributed/distributed.py +392 -69
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/_functions.py +1 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +10 -10
- mindspore/mint/nn/layer/normalization.py +11 -16
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +231 -239
- mindspore/nn/layer/activation.py +4 -2
- mindspore/nn/layer/basic.py +56 -14
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +32 -127
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +1 -4
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +2 -4
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +39 -5
- mindspore/nn/wrap/grad_reducer.py +4 -89
- mindspore/numpy/array_creations.py +4 -4
- mindspore/numpy/fft.py +9 -9
- mindspore/numpy/utils_const.py +1 -1
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +1 -5
- mindspore/ops/_op_impl/cpu/{buffer_append.py → joinedstr_op.py} +8 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +28 -24
- mindspore/ops/auto_generate/gen_extend_func.py +6 -11
- mindspore/ops/auto_generate/gen_ops_def.py +385 -154
- mindspore/ops/auto_generate/gen_ops_prim.py +5676 -5167
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +16 -2
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +2 -0
- mindspore/ops/function/array_func.py +24 -18
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +7 -6
- mindspore/ops/function/grad/grad_func.py +4 -12
- mindspore/ops/function/math_func.py +89 -86
- mindspore/ops/function/nn_func.py +92 -313
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +4 -1
- mindspore/ops/functional_overload.py +377 -30
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +12 -50
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +5 -50
- mindspore/ops/operations/comm_ops.py +95 -17
- mindspore/ops/operations/custom_ops.py +237 -22
- mindspore/ops/operations/debug_ops.py +33 -35
- mindspore/ops/operations/manually_defined/ops_def.py +39 -318
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +3 -3
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +4 -27
- mindspore/ops/tensor_method.py +88 -10
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +0 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_auto_parallel_context.py +5 -15
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +4 -6
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_utils.py +34 -17
- mindspore/parallel/auto_parallel.py +23 -9
- mindspore/parallel/checkpoint_transform.py +20 -2
- mindspore/parallel/cluster/process_entity/_api.py +28 -33
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/parallel/cluster/run.py +5 -3
- mindspore/{experimental/llm_boost/ascend_native → parallel/distributed}/__init__.py +21 -22
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/function/reshard_func.py +6 -5
- mindspore/parallel/nn/parallel_cell_wrapper.py +40 -3
- mindspore/parallel/nn/parallel_grad_reducer.py +0 -8
- mindspore/parallel/shard.py +7 -21
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +127 -20
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +13 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +40 -4
- mindspore/profiler/common/path_manager.py +65 -24
- mindspore/profiler/common/profiler_context.py +27 -14
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +10 -6
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/dynamic_profiler.py +91 -46
- mindspore/profiler/envprofiler.py +30 -5
- mindspore/profiler/experimental_config.py +18 -2
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +34 -7
- mindspore/profiler/profiler.py +193 -145
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +9 -6
- mindspore/runtime/executor.py +35 -0
- mindspore/runtime/memory.py +113 -0
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +8 -21
- mindspore/train/amp.py +6 -7
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +1 -17
- mindspore/train/callback/_flops_collector.py +10 -6
- mindspore/train/callback/_train_fault_tolerance.py +72 -25
- mindspore/train/data_sink.py +5 -9
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +41 -230
- mindspore/train/serialization.py +160 -401
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +152 -16
- mindspore/version.py +1 -1
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/RECORD +330 -344
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -207
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/__init__.py +0 -23
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/train/memory_profiling_pb2.py +0 -298
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0rc1.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/common/dynamic_shape/enable_dynamic.py
ADDED
@@ -0,0 +1,197 @@
+# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
+#
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Define enable_dynamic decorator."""
+import types
+import inspect
+from mindspore import log as logger
+from mindspore.common.tensor import Tensor
+from mindspore.common._utils import get_func, is_dim_unknown
+from mindspore.common.dynamic_shape.auto_dynamic_shape import SHAPE_DIM_ANY
+
+
+ENABLE_DYNAMIC = "__enable_dynamic__"
+
+
+def _check_element_valid(item, shape, name):
+    """Check elements in shape."""
+    if item is not SHAPE_DIM_ANY and (isinstance(item, int) and item <= 0):
+        raise TypeError(f"The argument '{name}' has invalid shape '{shape}', only supports None " \
+                        f"or a tuple/list of positive integers and None.")
+    return True
+
+
+def _check_arg_shape_valid(arg, name):
+    """Check if the shape of arg is valid"""
+    # if the shape of arg is None
+    if isinstance(arg, Tensor) and is_dim_unknown(arg.shape):
+        return True
+    if isinstance(arg, Tensor) and \
+            SHAPE_DIM_ANY in arg.shape and \
+            all(_check_element_valid(item, arg.shape, name) for item in arg.shape):
+        return True
+    if isinstance(arg, (tuple, list)) and any(_check_arg_shape_valid(item, name) for item in arg):
+        return True
+    return False
+
+
+def _check_arg_type_valid(arg, name):
+    """Check if the type of arg is valid."""
+    if isinstance(arg, Tensor):
+        return
+    if isinstance(arg, (tuple, list)):
+        for item in arg:
+            _check_arg_type_valid(item, name)
+    else:
+        raise TypeError(f"The decorator enable_dynamic only supports Tensor " \
+                        f"or a tuple/list of Tensor, but the argument : {name} is type of:{type(arg)}.")
+
+
+def _check_input_valid(arg):
+    """Check if real argument is valid."""
+    if isinstance(arg, Tensor):
+        if not all(isinstance(item, int) and item > 0 for item in arg.shape):
+            raise ValueError(f"When using decorator enable_dynamic, the corresponding shape of inputs should be " \
+                             f"a tuple/list of positive integers")
+    elif isinstance(arg, (tuple, list)):
+        for item in arg:
+            _check_input_valid(item)
+    else:
+        raise TypeError(f"When using decorator enable_dynamic, the corresponding inputs only supports Tensor " \
+                        f"or a tuple/list of Tensor.")
+
+
+def _check_arg_type_shape(arg, dyn_arg, name):
+    """Check the type, shape and dtype of real argument."""
+    if isinstance(arg, Tensor) and isinstance(dyn_arg, Tensor):
+        if arg.dtype != dyn_arg.dtype:
+            raise TypeError(f"When using decorator enable_dynamic, input tensor dtype = {arg.dtype}, " \
+                            f"dynamic tensor dtype = {dyn_arg.dtype}, tensor dtypes are not the same.")
+        if is_dim_unknown(dyn_arg.shape):
+            return
+        if len(arg.shape) != len(dyn_arg.shape) or \
+                any(y is not SHAPE_DIM_ANY and x != y for x, y in zip(arg.shape, dyn_arg.shape)):
+            raise ValueError(f"When using decorator enable_dynamic, input tensor shape = {arg.shape}, " \
+                             f"dynamic tensor shape = {dyn_arg.shape}, tensor shapes are not the same.")
+    elif isinstance(arg, (tuple, list)) and isinstance(dyn_arg, (tuple, list)):
+        if len(arg) != len(dyn_arg):
+            raise ValueError("Input sequences must have the same structure and length.")
+        for x, y in zip(arg, dyn_arg):
+            _check_arg_type_shape(x, y, name)
+    else:
+        raise TypeError(f"When using decorator enable_dynamic, the type between argument '{name}' " \
+                        f"and corresponding input are not the same.")
+
+
+def generate_dynamic_sequence_args(args_list, dyn_args_list):
+    """Generate dynamic shapes for input sequence"""
+    if isinstance(args_list, Tensor):
+        return dyn_args_list if args_list.shape != dyn_args_list.shape else args_list
+    result = []
+    for x, y in zip(args_list, dyn_args_list):
+        result.append(generate_dynamic_sequence_args(x, y))
+    return type(args_list)(result)
+
+
+def generate_dynamic_tensor_args(args_list, dynamic_shapes):
+    """Generate compile args with dynamic_shapes"""
+    new_compile_args = list(args_list)
+    for index, arg in enumerate(args_list):
+        if isinstance(arg, (tuple, list)) and not hasattr(arg, "__ms_mutable__"):
+            raise ValueError(f"When using decorator enable_dynamic, the corresponding attribute of input should be " \
+                             f"mutable(tuple/list)")
+        if index not in dynamic_shapes:
+            continue
+        _check_input_valid(arg)
+        name, dyn_arg = dynamic_shapes[index]
+        _check_arg_type_shape(arg, dyn_arg, name)
+        new_compile_args[index] = generate_dynamic_sequence_args(arg, dyn_arg)
+    logger.debug(f"args_list: {args_list}, dynamic_shapes: {dynamic_shapes}, " \
+                 f"new_compile_args: {new_compile_args}")
+    return new_compile_args
+
+
+def enable_dynamic(**kwargs):
+    """
+    Use to specify whether the shape of the parameter is dynamic shape or dynamic rank.
+
+    Note:
+        - It needs to be used in conjunction with the JIT interface. Without using the JIT decorator,
+          the dynamic shape and dynamic rank functions will not be enabled.
+        - In the scenario where both set_context(mode=GRAPH_MODE) and nn.Cell are set simultaneously,
+          use enabled_dynamic to report an error.
+
+    Args:
+        kwargs (dict): The input types are Tensor, tuple[Tensor] and list[Tensor]. If one or
+            more dimensions in the shape of the parameter need to be specified as dynamic shapes,
+            the corresponding dimensions in the shape can be set to None. If the shape that needs
+            to generate specified parameters is dynamic rank, the shape can be set to None.
+
+    Returns:
+        Function, return a function that specifies the dynamic shape information of the parameter.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor
+        >>> from mindspore import enable_dynamic
+        >>> from mindspore import jit
+        ...
+        >>> x = Tensor(np.random.randn(2, 3), ms.float32)
+        >>> y = Tensor(np.random.randn(2, 3), ms.float32)
+        ...
+        >>> # Specify parameter y as dynamic shape
+        >>> @enable_dynamic(y=Tensor(shape=None, dtype=ms.float32))
+        >>> @jit
+        >>> def func(x, y):
+        ...     return x + 1, y + 1
+        ...
+        >>> out = func(x, y)
+    """
+    # Check inputs at first.
+    if not kwargs:
+        raise ValueError(f"When using decorator enable_dynamic, the input cannot be empty!")
+    for name, arg in kwargs.items():
+        _check_arg_type_valid(arg, name)
+        if not _check_arg_shape_valid(arg, name):
+            raise TypeError(f"When using decorator enable_dynamic, the shape of argument '{name}' " \
+                            f"at least have one None.")

+    def decorator(func):
+        if not isinstance(func, (types.FunctionType, types.MethodType)):
+            raise ValueError(f"Decorator enable_dynamic can only be used for function or method " \
+                             f"decrocated by ms.jit, but got {func}.")
+        signature = inspect.signature(func)
+        sigs_name = [sig_name for sig_name in signature.parameters if sig_name != "self"]
+        if len(kwargs) > len(sigs_name):
+            raise ValueError(f"When using decorator enable_dynamic, the number of arguments {len(kwargs)} " \
+                             f"exceeds the number of function arguments {len(sigs_name)}.")
+        # Generate dynamic args.
+        dynamic_args = dict()
+        for key, value in kwargs.items():
+            index = sigs_name.index(key)
+            if index in dynamic_args:
+                raise ValueError(f"keyword argument repeated: {key}")
+            dynamic_args[index] = (key, value)
+        # Set dynamic_tensor_shape to func.
+        inner_func = inspect.unwrap(func, stop=lambda f: not hasattr(f, '__wrapped__'))
+        setattr(get_func(inner_func), ENABLE_DYNAMIC, dynamic_args)
+        logger.info(f"Set enable dynamic: {dynamic_args} to {inner_func}")
+        return func
+    return decorator
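A minimal usage sketch of the new decorator, following the docstring above (assumes a MindSpore 2.7.1 install; marking only the first dimension of `y` dynamic with `shape=[None, 3]` is an illustrative variation on the docstring's full dynamic-rank example):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, jit, enable_dynamic

    # Dim 0 of `y` is declared dynamic (None); dim 1 stays fixed at 3.
    @enable_dynamic(y=Tensor(shape=[None, 3], dtype=ms.float32))
    @jit
    def add_one(x, y):
        return x + 1, y + 1

    x = Tensor(np.random.randn(2, 3), ms.float32)
    y = Tensor(np.random.randn(4, 3), ms.float32)
    out = add_one(x, y)  # later calls may pass y with a different first dim

Per the checks in the body, each named argument must be a Tensor (or a tuple/list of Tensor) whose shape contains at least one None, and the names must match the decorated function's signature.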
mindspore/common/file_system.py
CHANGED
@@ -14,10 +14,14 @@
 # ============================================================================
 """File system registration management"""
 from mindspore import log as logger
+from mindspore import _checkparam as Validator
+
+mindio_server_info = {"memfs.data_block_pool_capacity_in_gb": "100"}
 
 
 class FileSystem:
     """File operation interface manager"""
+
     def __init__(self):
         self.create = open
         self.create_args = ("ab",)
@@ -35,20 +39,33 @@ def _register_basic_file_system(fs: FileSystem):
     return True
 
 
-def …
-    """…
+def _init_mindio():
+    """Initialize MindIO and return the module if successful"""
     try:
-        import mindio
+        import mindio_acp as mindio
+        ret = mindio.initialize(server_info=mindio_server_info)
+        if ret == 0:
+            return mindio
+        logger.warning(f"Failed to initialize mindio_acp: ret = {ret}")
     except ImportError:
-        …
+        pass
     try:
+        import mindio
         ret = mindio.initialize()
-        …
+        if ret == 0:
+            return mindio
+        logger.warning(f"Failed to initialize mindio: ret = {ret}")
+    except ImportError:
+        pass
+    return None
+
+
+def _register_mindio_file_system(fs: FileSystem):
+    """register mindio file system"""
+    mindio = _init_mindio()
+    if mindio is None:
         return False
+
     fs.create = mindio.create_file
     fs.create_args = ()
     fs.open = mindio.open_file
@@ -56,3 +73,36 @@ def _register_mindio_file_system(fs: FileSystem):
     fs.backend = "mindio"
     logger.info("The weights are stored using MindIO as the backend.")
     return True
+
+
+def set_mindio_server_info(data_block_pool_capacity_in_gb=100):
+    """
+    Configure MindIO server settings.
+
+    Args:
+        data_block_pool_capacity_in_gb (int): Memory pool capacity for data blocks in gigabytes.
+    """
+    global mindio_server_info
+    Validator.check_positive_int(data_block_pool_capacity_in_gb, "data_block_pool_capacity_in_gb")
+    mindio_server_info["memfs.data_block_pool_capacity_in_gb"] = str(data_block_pool_capacity_in_gb)
+
+
+def mindio_preload(ckpt_file_name):
+    """
+    Preload data into memory using MindIO for faster access.
+
+    Args:
+        ckpt_file_name (str): Checkpoint file name.
+
+    Returns:
+        bool: True if preloading is successful, False otherwise.
+    """
+    Validator.check_value_type('ckpt_file_name', ckpt_file_name, str, "mindio_preload")
+    mindio = _init_mindio()
+    if mindio is None:
+        return False
+    if not hasattr(mindio, 'preload'):
+        logger.warning("MindIO module does not have preload method")
+        return False
+    mindio.preload(ckpt_file_name)
+    return True

(Removed lines shown as … were elided by the diff renderer.)
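The two new helpers compose as follows (a sketch; MindIO is an optional dependency, and the checkpoint path is a placeholder):

    from mindspore.common.file_system import set_mindio_server_info, mindio_preload

    # Must be a positive int; stored as a string in mindio_server_info and
    # passed to mindio_acp.initialize(server_info=...) on the next init.
    set_mindio_server_info(data_block_pool_capacity_in_gb=200)

    # Returns False when neither mindio_acp nor mindio can be imported and
    # initialized, or when the module has no preload(); True otherwise.
    ok = mindio_preload("/path/to/model.ckpt")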
mindspore/common/hook_handle.py
CHANGED
@@ -15,7 +15,23 @@
 """The removable handle for cell hook function."""
 from __future__ import absolute_import
 import weakref
+from collections import OrderedDict
 from mindspore._c_expression import TensorPy as Tensor_
+from mindspore._check_jit_forbidden_api import jit_forbidden_register
+
+
+# Global variable to mark the `Parameter` hook and `Cell` hook version
+_HOOK_VERSION = 0
+
+
+def _update_hook_version():
+    global _HOOK_VERSION
+    _HOOK_VERSION += 1
+
+
+def _hook_version():
+    global _HOOK_VERSION
+    return _HOOK_VERSION
 
 
 class _TensorHookHandle:
@@ -31,8 +47,9 @@ class _TensorHookHandle:
 
     def __init__(self, tensor):
         self.id = None
-        self.…
+        self.tensor_weakref = weakref.ref(tensor)
 
+    @jit_forbidden_register
     def remove(self):
         """
         Remove the tensor hook function, which corresponds to this '_TensorHookHandle' object.
@@ -67,9 +84,9 @@ class _TensorHookHandle:
         """
         if self.id is not None:
             Tensor_.remove_hook(self.id)
-        tensor = self.…
+        tensor = self.tensor_weakref()
         if tensor is not None:
-            tensor._remove_hook()
+            tensor._remove_hook()  # pylint:disable=protected-access
 
 
 class HookHandle:
@@ -99,6 +116,7 @@ class HookHandle:
         if extra_dict is not None:
             self.extra_dict_ref = weakref.ref(extra_dict)
 
+    @jit_forbidden_register
     def remove(self):
         """
         Remove the cell hook function, which corresponds to this 'HookHandle' object.
@@ -145,6 +163,8 @@ class HookHandle:
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
+        _update_hook_version()  # pylint:disable=protected-access
+
         if self.hook_dict_ref is not None:
             hook_dict = self.hook_dict_ref()
             if hook_dict is not None and self.handle_id in hook_dict:
@@ -154,3 +174,62 @@ class HookHandle:
             extra_dict = self.extra_dict_ref()
             if extra_dict is not None and self.handle_id in extra_dict:
                 del extra_dict[self.handle_id]
+
+
+def _check_hook_results(pre_res, new_res, hook_fn):
+    if not isinstance(new_res, tuple):
+        raise RuntimeError(f"hook {hook_fn.__name__} should return a tuple of grad.")
+
+    new_res_len = len(new_res)
+    pre_res_len = len(pre_res)
+    if new_res_len != pre_res_len:
+        raise RuntimeError(
+            f"hook {hook_fn.__name__} returned incorrect length {new_res_len}, expected {pre_res_len}."
+        )
+
+
+class _HookUtils:
+    r"""
+    Internal utility class for hook registration and execution.
+    """
+
+    @staticmethod
+    def register_hook(hook_dict, hook_fn):
+        """
+        Register hook
+
+        Args:
+            hook_dict (dict): hook dict.
+            hook_fn (function): hook function.
+
+        Returns:
+            tuple: Updated hook_dict and HookHandle object.
+        """
+        if hook_dict is None:
+            hook_dict = OrderedDict()
+        handle = HookHandle(hook_dict)
+        hook_dict[handle.handle_id] = hook_fn
+        return hook_dict, handle
+
+    @staticmethod
+    def run_hook(hook_dict, args):
+        """
+        Run all hooks in the hook_dict with the given arguments.
+
+        Args:
+            hook_dict (dict): Dictionary of registered hooks.
+            args (tuple): Arguments to pass to the hook functions.
+
+        Returns:
+            Modified first argument if any hook returns a new value; otherwise, None.
+        """
+        is_modify = False
+        args_list = list(args)
+        # Note: We create a list from hook_dict.values() to ensure safe iteration.
+        for hook_fn in list(hook_dict.values()):
+            res = hook_fn(*args_list)
+            if res is not None:
+                _check_hook_results(args_list[0], res, hook_fn)
+                args_list[0] = res
+                is_modify = True
+        return args_list[0] if is_modify else None

(Removed lines shown as … were elided by the diff renderer.)
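A sketch of how the new internal helpers fit together (`_HookUtils` is private API; the doubling hook is a made-up example):

    from mindspore import Tensor
    from mindspore.common.hook_handle import _HookUtils

    def double_grads(grads):
        # run_hook expects a returned tuple with the same length as its first argument.
        return tuple(g * 2 for g in grads)

    grads = (Tensor([1.0]), Tensor([2.0]))
    hook_dict, handle = _HookUtils.register_hook(None, double_grads)
    new_grads = _HookUtils.run_hook(hook_dict, (grads,))  # values 2.0 and 4.0
    handle.remove()  # deregisters the hook and bumps the global hook version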
mindspore/common/jit_config.py
CHANGED
@@ -27,7 +27,11 @@ class JitConfig:
           adopt KernelByKernel execution mode.
         - ``"O1"``: Using commonly used optimizations and automatic operator fusion optimizations,
           adopt KernelByKernel execution mode.
-        - ``"O2"``: …
+        - ``"O2"``: Utilizes the GraphEngine, a graph compilation and execution engine within CANN,
+          for Ascend model compilation and execution. Note: O2 only supports GRAPH Mode in Ascend,
+          only supports whole graph sinking or sub graph sinking in pipeline parallel, and does not support
+          dynamic shape scenes. In addition, this mode incurs additional compilation costs and is difficult to
+          debug and tune.
 
     exc_mode (str, optional): Control the execution mode of the model.
         Supports ["auto", "sink", "no_sink"]. Default: ``"auto"`` .
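A sketch of opting into the documented O2 level (Ascend, graph mode, static shapes only; attaching the config via `Cell.set_jit_config` is one common route):

    import mindspore as ms
    from mindspore import nn, JitConfig

    ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend")
    net = nn.Dense(3, 4)
    net.set_jit_config(JitConfig(jit_level="O2"))  # GE-backed compilation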
mindspore/common/jit_trace.py
CHANGED
@@ -28,6 +28,7 @@ from mindspore._c_expression import TraceRecorder as tr
 from mindspore._c_expression import JitExecutor_
 from mindspore._c_expression import TensorPy as Tensor, CSRTensor, COOTensor
 from mindspore._c_expression import typing
+from mindspore.common.jit_config import JitConfig
 
 
 class TraceJitContext(JitContext):
@@ -123,19 +124,19 @@ def nested_run(obj, cell, *args):
     return file_names, linenos, res
 
 
-def _jit_trace():
+def _jit_trace(jit_config):
     """Return the wrapped function for trace mode jit."""
     def wrap_func(fn):
         if hasattr(fn, "construct"):
             if isinstance(fn, ms.nn.Cell):
                 # Bound the cell object to get the self arg.
-                return types.MethodType(_jit_trace()(fn.construct.__func__), fn)
+                return types.MethodType(_jit_trace(jit_config)(fn.construct.__func__), fn)
             if isinstance(fn, type) and issubclass(fn, ms.nn.Cell):
-                fn.construct = _jit_trace()(fn.construct)
+                fn.construct = _jit_trace(jit_config)(fn.construct)
                 return fn
 
         if isinstance(fn, types.MethodType):
-            return types.MethodType(_jit_trace()(fn.__func__), fn.__self__)
+            return types.MethodType(_jit_trace(jit_config)(fn.__func__), fn.__self__)
 
         if not isinstance(fn, types.FunctionType):
             logger.warning(f"The fn should be function, method or cell instance/class, but got {fn}")
@@ -150,6 +151,10 @@ def _jit_trace():
         if jit_context():
             return fn(*args, **kwargs)
         # Start trace process.
+        if jit_config:
+            jit_config_dict = jit_config.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
         if kwargs:
             bound_arguments = inspect.signature(fn).bind(*args, **kwargs)
             bound_arguments.apply_defaults()
@@ -170,14 +175,16 @@ def _jit_trace():
         line_str = fn.__code__.co_filename + ":" + str(fn.__code__.co_firstlineno)
         generate_name = generate_name + '#[' + line_str + ']'
 
-        new_compile = _jit_trace_begin(…
+        new_compile = _jit_trace_begin(
+            generate_name, *jit_args, jit_config=jit_config_dict)
         if new_compile:
             fn_res = fn(*args, **kwargs)
             logger.debug(f'fn: {fn}, fn_res: {fn_res}, line: {line_str}')
             # Use fn's output to build func graph's output.
-            output = _jit_trace_end(fn_res)
+            output = _jit_trace_end(fn_res, jit_config=jit_config_dict)
         else:
-            …
+            # Run with compilation.
+            output = _jit_trace_end(None, jit_config=jit_config_dict)
         logger.debug(f'output: {output}')
         return output
 
@@ -224,7 +231,7 @@ def _get_args_for_run(args):
     return tuple(new_args)
 
 
-def _jit_trace_begin(fn_name, *args):
+def _jit_trace_begin(fn_name, *args, **kwargs):
     """
     Start to build a MindIR func graph for a code snippet by trace method.
 
@@ -257,6 +264,10 @@ def _jit_trace_begin(fn_name, *args):
     ...
     >>> out = tensor_add(x, y)
     """
+    if "jit_config" in kwargs:
+        jit_config = kwargs["jit_config"]
+    else:
+        jit_config = JitConfig().jit_config_dict
     global _using_trace
     if _using_trace:
         raise RuntimeError(
@@ -279,7 +290,7 @@ def _jit_trace_begin(fn_name, *args):
     if not _compile_only and phase in _trace_compile_cache:
         logger.debug('Had compiled, just run.')
         _trace_jit_context.compiled = True
-        output = tr.get_instance().run_graph(phase, args)
+        output = tr.get_instance().run_graph(phase, jit_config, args)
         from mindspore.common.api import _convert_python_data
         _trace_jit_context.result = _convert_python_data(output)
         logger.debug(f'jit trace result: {_trace_jit_context.result}')
@@ -295,7 +306,7 @@ def _jit_trace_begin(fn_name, *args):
     return True
 
 
-def _jit_trace_end(*output_args):
+def _jit_trace_end(*output_args, **kwargs):
     """
     Finish building a MindIR func graph for a code snippet by trace method.
 
@@ -330,19 +341,23 @@ def _jit_trace_end(*output_args):
     ...
     >>> out = tensor_add(x, y)
     """
+    if "jit_config" in kwargs:
+        jit_config = kwargs["jit_config"]
+    else:
+        jit_config = JitConfig().jit_config_dict
     if _trace_jit_context.compiled:
         output = _trace_jit_context.result
         logger.debug(f'jit trace result: {output}')
     else:
         logger.debug(f'output_args: {output_args}')
         file_names, linenos = _get_caller_lines()
-        tr.get_instance().end_graph(file_names, linenos, *output_args)
+        tr.get_instance().end_graph(file_names, linenos, jit_config, *output_args)
         if _compile_only:
             output = output_args[0] if len(output_args) == 1 else output_args
         else:
             args = _get_args_for_run(_trace_jit_context.args)
             output = tr.get_instance().run_graph(
-                _trace_jit_context.phase, args)
+                _trace_jit_context.phase, jit_config, args)
             from mindspore.common.api import _convert_python_data
             output = _convert_python_data(output)
         logger.debug(f'jit trace result: {output}')

(Removed lines shown as … were elided by the diff renderer.)
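The change threads an optional jit_config dict through **kwargs so that older positional call sites keep working; the pattern in isolation (standalone sketch, not MindSpore code):

    def _begin(fn_name, *args, **kwargs):
        # Callers that predate the new argument silently get the default config.
        jit_config = kwargs.get("jit_config", {"jit_level": "O0"})
        return fn_name, args, jit_config

    _begin("graph_0")                                  # old style: default config
    _begin("graph_0", jit_config={"jit_level": "O1"})  # new style: explicit config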
mindspore/common/lazy_inline.py
CHANGED
@@ -32,9 +32,11 @@ def lazy_inline(fn=None, attrs=None, policy=None):
     static_graph_expert_programming.html#using-lazy-inline-decorator>`_ .
 
     .. warning::
-        This feature is only supported on Ascend and is not supported on other hardwares.
-        The construct parameters must be positional or key word arguments and have not default values.
-        The cell has not switch sub graph.
+        - This feature is only supported on Ascend and is not supported on other hardwares.
+        - The construct parameters must be positional or key word arguments and have not default values.
+        - The cell has not switch sub graph.
+        - In the gradient accumulation scenario, it is recommended to use the @lazy_inline decorator to
+          reduce compilation time, and this decorator is only allowed to configure on the outermost cell.
 
     Args:
         fn (function): `__init__` function of a cell.
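Following the new warning item, in a gradient-accumulation setup the decorator is placed on the `__init__` of the outermost cell only (sketch; the inner cell is a placeholder):

    import mindspore as ms
    from mindspore import nn

    class OuterNet(nn.Cell):
        @ms.lazy_inline  # outermost cell only; Ascend only
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def construct(self, x):
            return self.inner(x)

    net = OuterNet(nn.Dense(3, 3))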
mindspore/common/np_dtype.py
CHANGED
@@ -16,10 +16,10 @@
 # ============================================================================
 """Numpy data type for MindSpore."""
 
-from mindspore._c_expression.np_dtypes import …
-if …
+from mindspore._c_expression.np_dtypes import np_dtype_valid
+if np_dtype_valid(True):
     from mindspore._c_expression.np_dtypes import bfloat16  # pylint: disable=unused-import
 
 __all__ = []
-if …
+if np_dtype_valid(False):
     __all__.extend(["bfloat16"])

(Removed lines shown as … were elided by the diff renderer.)
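When the installed binary reports a valid numpy bfloat16 (the np_dtype_valid gate above), the exported type behaves as an ordinary numpy dtype (sketch; availability depends on the build):

    import numpy as np
    from mindspore.common.np_dtype import bfloat16  # only exported when valid

    a = np.array([1.5, 2.25], dtype=bfloat16)
    b = a.astype(np.float32)  # round-trips through bfloat16 precision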