mindspore-2.6.0rc1-cp39-cp39-win_amd64.whl → mindspore-2.7.0rc1-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +40 -9
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parse/__init__.py +2 -2
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -1
- mindspore/_extends/parse/parser.py +37 -62
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +43 -13
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/amp.py +4 -4
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/common/__init__.py +27 -2
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +77 -16
- mindspore/common/api.py +238 -113
- mindspore/common/dtype.py +21 -11
- mindspore/common/dump.py +10 -15
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +11 -2
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/parameter.py +26 -12
- mindspore/common/recompute.py +3 -3
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +81 -81
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +58 -40
- mindspore/dataset/core/config.py +3 -3
- mindspore/dataset/engine/datasets.py +20 -7
- mindspore/dataset/engine/datasets_user_defined.py +33 -3
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +7 -3
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -7
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +37 -1
- mindspore/include/api/delegate.h +10 -0
- mindspore/include/api/model.h +3 -0
- mindspore/include/api/types.h +2 -2
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +60 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +1 -0
- mindspore/mint/distributed/distributed.py +212 -9
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +164 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +137 -101
- mindspore/mint/nn/layer/normalization.py +8 -22
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/nn/cell.py +328 -502
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +32 -34
- mindspore/nn/layer/basic.py +67 -64
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +37 -39
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +3 -3
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +34 -37
- mindspore/nn/wrap/grad_reducer.py +37 -37
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +5 -5
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +42 -11
- mindspore/ops/auto_generate/gen_extend_func.py +23 -141
- mindspore/ops/auto_generate/gen_ops_def.py +727 -321
- mindspore/ops/auto_generate/gen_ops_prim.py +1721 -984
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +8 -4
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +3 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +9 -96
- mindspore/ops/function/debug_func.py +4 -3
- mindspore/ops/function/grad/grad_func.py +1 -1
- mindspore/ops/function/math_func.py +33 -540
- mindspore/ops/function/nn_func.py +28 -74
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +2 -3
- mindspore/ops/functional_overload.py +571 -6
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +3 -6
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +2 -2
- mindspore/ops/operations/comm_ops.py +185 -26
- mindspore/ops/operations/custom_ops.py +294 -174
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +15 -16
- mindspore/ops/operations/math_ops.py +3 -4
- mindspore/ops/operations/nn_ops.py +7 -39
- mindspore/ops/primitive.py +6 -10
- mindspore/ops/tensor_method.py +47 -8
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +11 -8
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +1 -1
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +13 -8
- mindspore/parallel/auto_parallel.py +14 -7
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +11 -7
- mindspore/parallel/cluster/process_entity/_api.py +84 -48
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +43 -4
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +6 -7
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +9 -10
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -6
- mindspore/parallel/shard.py +3 -4
- mindspore/parallel/transform_safetensors.py +463 -174
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +42 -22
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/profiler_context.py +25 -27
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_op_analyse.py +235 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +305 -314
- mindspore/profiler/envprofiler.py +12 -7
- mindspore/profiler/experimental_config.py +96 -6
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/npu_profiler.py +29 -19
- mindspore/profiler/profiler.py +35 -19
- mindspore/profiler/profiler_action_controller.py +64 -76
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +5 -5
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +60 -45
- mindspore/runtime/memory.py +30 -32
- mindspore/runtime/thread_bind_core.py +298 -164
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +14 -4
- mindspore/train/amp.py +43 -20
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_checkpoint.py +3 -6
- mindspore/train/callback/_flops_collector.py +1 -1
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +97 -16
- mindspore/train/data_sink.py +11 -2
- mindspore/train/dataset_helper.py +9 -0
- mindspore/train/model.py +135 -55
- mindspore/train/serialization.py +133 -111
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +6 -9
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +333 -371
- mindspore/_deprecated/jit.py +0 -198
- mindspore/experimental/es/__init__.py +0 -22
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/api.py
CHANGED
@@ -1,6 +1,6 @@
 # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 #
-# Copyright 2020-
+# Copyright 2020-2025 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
 """Providing interface methods."""
 from __future__ import absolute_import
 
+__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']
+
 import gc
 import types
 import sys
@@ -42,7 +44,7 @@ from mindspore.common.sparse_tensor import RowTensor as PythonRowTensor
 from mindspore._c_expression.amp import get_curr_amp_strategy
 from mindspore._c_expression import GraphExecutor_, JitExecutor_, CSRTensor, RowTensor, COOTensor, \
     PyNativeExecutor_, verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_pipeline, \
-    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx,
+    _run_jit_pipeline, _ms_memory_recycle, _bind_device_ctx, MSContext, TensorPy as Tensor
 from mindspore.parallel._ps_context import _is_role_sched
 from mindspore.parallel._utils import _check_full_batch, _get_parameter_broadcast, _is_in_auto_parallel_mode, \
     _is_parallel_mode
@@ -58,7 +60,7 @@ from mindspore.common.jit_context import jit_context
 from mindspore.common.jit_trace import _jit_trace
 from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
 
-# Store
+# Store jit class compiled pipeline cache.
 ms_compile_cache = set()
 # Store cell compiled pipeline cache.
 cells_compile_cache = {}
@@ -134,8 +136,6 @@ def _convert_python_data(data):
     """
     if isinstance(data, PythonTensor):
         return data
-    if isinstance(data, StubNode):
-        return ms.common._stub_tensor._convert_stub(data)
    if data.__class__ is tuple:
        # Handle namedtuple since its type is tuple.
        if hasattr(data, "_fields"):
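
The hunk above removes the StubNode branch from `_convert_python_data`, so stub tensors no longer flow through this conversion path in 2.7. The namedtuple special case it keeps is the subtle part: a namedtuple passes tuple checks but must be rebuilt through its own constructor. Below is a minimal standalone sketch of that recursive conversion; the function name and the list/dict branches are illustrative assumptions, only the `_fields`-based namedtuple detection is taken from the diff.

```python
from collections import namedtuple

def convert_output(data):
    # Namedtuples are detected via the _fields attribute, since a plain
    # tuple check alone cannot tell them apart (illustrative sketch).
    if isinstance(data, tuple):
        if hasattr(data, "_fields"):
            return type(data)(*(convert_output(v) for v in data))
        return tuple(convert_output(v) for v in data)
    if isinstance(data, list):
        return [convert_output(v) for v in data]
    if isinstance(data, dict):
        return {k: convert_output(v) for k, v in data.items()}
    return data

Point = namedtuple("Point", ["x", "y"])
print(convert_output((Point(1, 2), [3, {"k": 4}])))
# (Point(x=1, y=2), [3, {'k': 4}])
```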
@@ -278,13 +278,13 @@ def __get_compile_cache_dep_files(file_path, compile_cache_dep_files, pkg):
         module = importlib.util.module_from_spec(module_spec)
         if hasattr(module, '__file__'):
             dep_file_path = module.__file__
+            # Exclude the installed modules.
+            if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
+                logger.debug(f"dependent file path: {dep_file_path}")
+                compile_cache_dep_files.append(dep_file_path)
+                __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
         else:
             continue
-        # Exclude the installed modules.
-        if not _in_sys_path(dep_file_path) and dep_file_path not in compile_cache_dep_files:
-            logger.debug(f"dependent file path: {dep_file_path}")
-            compile_cache_dep_files.append(dep_file_path)
-            __get_compile_cache_dep_files(dep_file_path, compile_cache_dep_files, module.__package__)
 
 
 def _get_compile_cache_dep_files():
@@ -342,7 +342,7 @@ def _get_parameter_layout():
     return layout
 
 
-def _handle_arg(obj, arg, compile_arg):
+def _handle_arg(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     from mindspore._extends.parse import compile_config
     if isinstance(arg, PythonTensor):
@@ -352,7 +352,7 @@ def _handle_arg(obj, arg, compile_arg):
         return arg
     elif isinstance(arg, (Tensor, CSRTensor, COOTensor)):
         return arg
-    elif
+    elif has_mutable_arg:
         # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
         if isinstance(arg, list) and not arg:
             return None
@@ -366,7 +366,7 @@ def _handle_arg(obj, arg, compile_arg):
         return None
 
 
-def _handle_arg_predict(obj, arg, compile_arg):
+def _handle_arg_predict(obj, arg, has_mutable_arg):
     """Handle arg for runtime .If need handle the arg, return True"""
     if arg is None:
         return None
@@ -375,8 +375,7 @@ def _handle_arg_predict(obj, arg, compile_arg):
         return None
 
     if isinstance(arg, (list, tuple)):
-        if
-                getattr(compile_arg, "__ms_mutable__"):
+        if has_mutable_arg:
             # mutable([]) will be eliminated by FuncGraphSpecializer, and empty list is not supported by backend.
             if isinstance(arg, list) and not arg:
                 return None
@@ -388,35 +387,30 @@ def _handle_arg_predict(obj, arg, compile_arg):
         return arg
 
 
-def _get_args_for_run(obj, args, kwargs,
+def _get_args_for_run(obj, args, kwargs, has_mutable_args_list, is_predict):
     """Get the actual input args and kwargs for runtime."""
     new_args = []
-
-
+    fn = _handle_arg_predict if is_predict else _handle_arg
+    for arg, has_mutable_arg in zip(args, has_mutable_args_list):
+        new_arg = fn(obj, arg, has_mutable_arg)
        if new_arg is not None:
            new_args.append(new_arg)
 
     for _, value in kwargs.items():
-        new_value =
+        new_value = fn(obj, value, None)
         if new_value is not None:
             new_args.append(new_value)
 
     return new_args
 
 
-def
-    """Get
+def _get_mutable_flags(compile_args):
+    """Get a list of booleans indicating whether each argument is marked as mutable"""
     new_args = []
-    for
-
-
-
-
-    for _, value in kwargs.items():
-        new_value = _handle_arg_predict(obj, value, None)
-        if new_value is not None:
-            new_args.append(new_value)
-
+    for compile_arg in compile_args:
+        has_mutable_arg = compile_arg is not None and hasattr(compile_arg, "__ms_mutable__") and \
+            getattr(compile_arg, "__ms_mutable__")
+        new_args.append(has_mutable_arg)
     return new_args
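
This refactor splits argument preparation in two: `_get_mutable_flags` precomputes one boolean per compiled argument (is it tagged with `__ms_mutable__`?), and `_get_args_for_run` then consumes those booleans instead of re-inspecting the compile args on every call. A standalone sketch of the flag extraction, where `FakeMutable` is a hypothetical stand-in for an argument produced by `mutable()`:

```python
def get_mutable_flags(compile_args):
    # True only when the compiled argument exists and carries a truthy
    # __ms_mutable__ marker, mirroring the check in the hunk above.
    return [bool(arg is not None and getattr(arg, "__ms_mutable__", False))
            for arg in compile_args]

class FakeMutable:  # hypothetical stand-in for a mutable()-tagged argument
    __ms_mutable__ = True

print(get_mutable_flags([FakeMutable(), 42, None]))  # [True, False, False]
```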
@@ -586,7 +580,8 @@ class _JitExecutor:
         The result of pipeline running in graph mode.
     """
 
-    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0
+    def __init__(self, fn, ms_create_time, input_signature=None, obj=None, jit_config=None, dynamic=0,
+                 cell_cache_key_extend=''):
         init_pipeline()
         if not isinstance(fn, (types.FunctionType, types.MethodType)):
             raise RuntimeError('fn {} is not function or method'.format(fn))
@@ -606,9 +601,57 @@ class _JitExecutor:
         self._compile_args = None
         self._enable_auto_dynamic = dynamic == 1
         self.jit_config_dict = jit_config.jit_config_dict if jit_config else None
+        self._cell_cache_key_extend = cell_cache_key_extend
+
+    def _predict(self, *args, **kwargs):
+        """Dedicated routine for predict."""
+        if not hasattr(self.obj, "phase"):
+            return False, None
+
+        predict_vailid_phase = {"prefill", 'increment'}
+        predict_phase = self.obj.phase
+        if predict_phase not in predict_vailid_phase:
+            return False, None
+
+        args_list = args
+        if self.obj is not None:
+            args_list = args_list[1:]
+
+        if predict_phase not in self.obj.phase_cache:
+            try:
+                predict_phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            except Exception as err:
+                _pynative_executor.clear_res()
+                raise err
+        else:  # get compiled args to generate run args by _generate_run_args
+            compile_args = self._generate_compile_args(args_list)
+            key_id = self._get_key_id()
+            compile_args = get_auto_dynamic_shape_args_with_check_input_signature(
+                compile_args,
+                key_id,
+                self.input_signature,
+                self._enable_auto_dynamic
+            )
+            self._compile_args = compile_args
+
+        new_inputs = self._generate_run_args(args_list, kwargs)
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
+        else:
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
+        output = self._graph_executor(
+            tuple(new_inputs),
+            self.obj.phase_cache[self.obj.phase]
+        )
+        res = _convert_python_data(output)
+        return True, res
 
     @_wrap_func
     def __call__(self, *args, **kwargs):
+        predict, res = self._predict(*args, **kwargs)
+        if predict:
+            return res
         if jit_context() and jit_context().is_nested():
             return jit_context().run_graph("", None, *())
         args_list = args
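
The new `_predict` fast path only engages for cells whose `phase` is one of the two inference phases, and it keys compiled graphs by that phase string through `phase_cache`, so the prefill and increment graphs coexist without recompiling each other. A reduced sketch of the compile-once-per-phase pattern; the class and method names below are illustrative, not MindSpore API:

```python
class PhaseCachedRunner:
    """Sketch of the compile-once, run-by-phase pattern from _predict."""

    PREDICT_PHASES = {"prefill", "increment"}

    def __init__(self):
        self.phase_cache = {}  # phase name -> compiled-graph handle

    def _compile(self, phase):
        return f"compiled::{phase}"  # stand-in for real graph compilation

    def run(self, phase):
        if phase not in self.PREDICT_PHASES:
            return None  # fall back to the ordinary __call__ path
        if phase not in self.phase_cache:
            self.phase_cache[phase] = self._compile(phase)
        return self.phase_cache[phase]

runner = PhaseCachedRunner()
assert runner.run("prefill") is runner.run("prefill")  # second call hits the cache
```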
@@ -616,12 +659,9 @@ class _JitExecutor:
             args_list = args_list[1:]
         phase = ""
         try:
-
-
-
-                _pynative_executor.set_jit_compile_status(False, phase)
-            else:
-                phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(True, phase)
+            phase = self.compile(self.fn.__name__, *args_list, **kwargs)
+            _pynative_executor.set_jit_compile_status(False, phase)
         except Exception as err:
             _pynative_executor.clear_res()
             raise err
@@ -630,15 +670,16 @@ class _JitExecutor:
             return None
 
         new_inputs = self._generate_run_args(args_list, kwargs)
-        if
-
+        if self.jit_config_dict:
+            jit_config_dict = self.jit_config_dict
         else:
-
-
-
-
-
+            jit_config_dict = JitConfig().jit_config_dict
+        self._graph_executor.set_jit_config(jit_config_dict)
+        output = _pynative_executor.grad_jit(*new_inputs)
+        if jit_context():
+            if is_stub_tensor(output):
+                output = output.stub_sync()
+            return jit_context().run_graph(phase, output, *tuple(new_inputs))
         return output
 
     def compile(self, method_name, *args, **kwargs):
@@ -674,7 +715,7 @@ class _JitExecutor:
                                    f'`{self.fn.__module__}`')
             self.obj.__parse_method__ = method_name
             if isinstance(self.obj, ms.nn.Cell):
-                generate_name = generate_name + '.' + str(self.obj.create_time)
+                generate_name = generate_name + '.' + str(self.obj.create_time) + self.obj.phase
                 create_time = str(self.obj.create_time)
             else:
                 generate_name = generate_name + '.' + str(self._create_time)
@@ -705,6 +746,8 @@ class _JitExecutor:
 
         update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
 
+        phase = phase + self._cell_cache_key_extend
+
         if phase in ms_compile_cache and self._graph_executor.has_compiled(phase) and not parameter_hook_updated():
             # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
             # generated in generate_arguments_key.
@@ -716,10 +759,9 @@ class _JitExecutor:
         # If enable compile cache, get the dependency files list and set to graph executor.
         self._set_compile_cache_dep_files()
         if self.jit_config_dict:
-            self.
+            jit_config_dict = self.jit_config_dict
         else:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)
 
         if self.obj is None:
             # Set an attribute to fn as an identifier.
@@ -727,7 +769,8 @@ class _JitExecutor:
                 setattr(self.fn.__func__, "__jit_function__", True)
             else:
                 setattr(self.fn, "__jit_function__", True)
-            is_compile = self._graph_executor.compile(
+            is_compile = self._graph_executor.compile(
+                self.fn, compile_args, kwargs, phase, jit_config_dict)
             if isinstance(self.fn, types.MethodType):
                 delattr(self.fn.__func__, "__jit_function__")
             else:
@@ -735,12 +778,15 @@ class _JitExecutor:
         else:
             if isinstance(self.obj, ms.nn.Cell):
                 self._graph_executor.set_weights_values(self.obj.parameters_dict())
-            is_compile = self._graph_executor.compile(
+            is_compile = self._graph_executor.compile(
+                self.obj, compile_args, kwargs, phase, jit_config_dict)
 
         if not is_compile:
             raise RuntimeError("Executor compile failed.")
         set_parameter_hook_updated(False)
         ms_compile_cache.add(phase)
+        if hasattr(self.obj, "phase"):
+            self.obj.phase_cache[self.obj.phase] = phase
 
         return phase
 
@@ -832,7 +878,7 @@ class _JitExecutor:
         Returns:
             new_inputs, new input args, which are required for running.
         """
-        return _get_args_for_run(self, args_list, kwargs, self._compile_args)
+        return _get_args_for_run(self, args_list, kwargs, _get_mutable_flags(self._compile_args), False)
 
     def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False, incremental=False):
         """Get graph proto from pipeline."""
@@ -993,6 +1039,68 @@ def _check_options(options, backend):
             _check_option_value(option, value)
 
 
+def _jit_ast(hash_obj, dynamic, jit_config, jit_graph_name):
+    """Return the wrapped function for ast mode jit."""
+    def wrap_func(func):
+        nonlocal hash_obj
+        if hasattr(func, "construct"):
+            if isinstance(func, ms.nn.Cell):
+                # Bound the cell object to get the self arg.
+                return types.MethodType(_jit_ast(
+                    hash_obj, dynamic, jit_config, func._jit_graph_name)(func.construct.__func__), func)
+            if isinstance(func, type) and issubclass(func, ms.nn.Cell):
+                func.construct = _jit_ast(
+                    hash_obj, dynamic, jit_config, '')(func.construct)
+                return func
+
+        if isinstance(func, types.MethodType):
+            return types.MethodType(_jit_ast(hash_obj, dynamic, jit_config, '')(func.__func__), func.__self__)
+
+        if not isinstance(func, types.FunctionType):
+            logger.warning(f"The func should be function, method or cell instance/class, but got {func}")
+            return func
+
+        if hasattr(func, "__wrapped_by_jit__"):
+            logger.warning(f"The func {func} should be wrapped by jit only once.")
+
+        if hash_obj is None or not _is_inner_func(func):
+            hash_obj = int(time.time() * 1e9)
+
+        @wraps(func)
+        def staging_specialize(*args, **kwargs):
+            if os.getenv("MS_JIT") == '0':
+                return func(*args, **kwargs)
+
+            args, kwargs = _handle_func_args(func, *args, **kwargs)
+            process_obj = None
+            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
+                process_obj = args[0]
+            # Handle auto mixed precision strategy.
+            if not hasattr(func, "amp_strategy"):
+                if isinstance(func, types.MethodType):
+                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
+                else:
+                    setattr(func, "amp_strategy", get_curr_amp_strategy())
+
+            jit_graph_name = ''
+            if hasattr(staging_specialize, "__jit_graph_name__"):
+                jit_graph_name = staging_specialize.__jit_graph_name__
+            jit_executor = _JitExecutor(
+                func, hash_obj, None, process_obj, jit_config, dynamic, jit_graph_name)
+            out = jit_executor(*args, **kwargs)
+            return out
+
+        # `inspect.getfullargspec(func)` will get the specification of the decorated function by default. By set
+        # `__signature__` for the decorated function, `inspect.getfullargspec(func)` will get the specification of
+        # original `func`.
+        staging_specialize.__signature__ = inspect.signature(func)
+        setattr(staging_specialize, "__wrapped_by_jit__", True)
+        setattr(staging_specialize, "__jit_graph_name__", jit_graph_name)
+        return staging_specialize
 
+    return wrap_func
+
+
 def jit(
     function: Optional[Callable] = None,
     *,
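
`_jit_ast` consolidates the ast-mode wrapping that previously lived inline in `jit()`, and its `wrap_func` dispatches on what it receives: a `Cell` instance (wrap `construct` and rebind it), a `Cell` subclass (patch `construct` on the class), a bound method (unwrap, wrap, rebind), or a plain function. A reduced sketch of that dispatch, using an ordinary class in place of `nn.Cell` and eliding the actual compilation:

```python
import types
from functools import wraps

def jit_like(target):
    # Dispatch mirrors wrap_func in the hunk above (sketch only).
    if hasattr(target, "construct"):
        if not isinstance(target, type):  # instance: wrap and rebind construct
            return types.MethodType(jit_like(type(target).construct), target)
        target.construct = jit_like(target.construct)  # class: patch in place
        return target
    if isinstance(target, types.MethodType):  # bound method: unwrap and rebind
        return types.MethodType(jit_like(target.__func__), target.__self__)

    @wraps(target)
    def staging_specialize(*args, **kwargs):
        return target(*args, **kwargs)  # graph compilation elided in this sketch
    staging_specialize.__wrapped_by_jit__ = True
    return staging_specialize

class Net:  # stand-in for an nn.Cell subclass
    def construct(self, x):
        return x + 1

print(jit_like(Net())(2))  # 3: instance dispatch wrapped and rebound construct
```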
@@ -1015,22 +1123,22 @@ def jit(
     and the decoration @jit(capture_mode=“bytecode”) is considered invalid.
 
     Args:
-        function (
+        function (Callable, optional): The Python function or Cell that will be run as a graph. Default: ``None``.
 
     Keyword Args:
         capture_mode (str, optional): The method to create a callable MindSpore graph. The value of capture_mode
             should be ``ast`` , ``bytecode`` or ``trace`` . Default: ``ast`` .
 
-            - `ast <https://www.mindspore.cn/
+            - `ast <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#ast>`_ :
               Parse Python ast to build graph.
-            - `bytecode
+            - `bytecode <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#bytecode>`_ :
               Parse Python bytecode to build graph at runtime. This is an experimental prototype that is subject to
               change and/or deletion.
-            - `trace
+            - `trace <https://www.mindspore.cn/docs/en/master/features/compile/graph_construction.html#trace>`_ : Trace the execution of Python code to build graph. This is an experimental prototype that is
              subject to change and/or deletion.
 
         jit_level (str, optional): Used to control the compilation optimization level. Currently is only effective
-            with
+            with ms_backend. The value of jit_level should be ``O0`` or ``O1`` . Default: ``O0`` .
 
             - `O0`: Except for optimizations that may affect functionality, all other optimizations are turned off.
             - `O1`: Using commonly used optimizations and automatic operator fusion optimizations. This optimization
@@ -1045,8 +1153,8 @@ def jit(
         fullgraph (bool, optional): Whether to capture the entire function into graph. If False, jit attempts to
             be compatible with all Python syntax in the function as much as possible. If True, we require that the
             entire function can be captured into graph. If this is not possible (that is, if there is Python syntax
-            not supported), then it will raise an exception. This currently only applies when capture_mode is ast
-            Default: ``False``.
+            not supported), then it will raise an exception. This currently only applies when capture_mode is ``ast``
+            or ``bytecode``. Default: ``False``.
         backend (str, optional): The compilation backend to be used. If this parameter is not set, the framework will
             use ``GE`` backend for Atlas training series products and ``ms_backend`` backend for others including Atlas
             A2 training series products by default.
@@ -1114,29 +1222,84 @@ def jit(
         >>> x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         >>> y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph by calling jit.
         >>> def tensor_add(x, y):
         ...     z = x + y
         ...     return z
         ...
         >>> tensor_add_graph = jit(function=tensor_add)
         >>> out = tensor_add_graph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph through decorator @jit.
        >>> @jit
        ... def tensor_add_with_dec(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_with_dec(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
         ...
-        >>> #
+        >>> # Create a callable MindSpore graph and capture the entire function into the graph.
        >>> @jit(fullgraph=True)
        ... def tensor_add_fullgraph(x, y):
        ...     z = x + y
        ...     return z
        ...
        >>> out = tensor_add_fullgraph(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph by trace mode.
+        >>> @jit(capture_mode="trace")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with ms_backend and jit_level="O1".
+        >>> @jit(backend="ms_backend", jit_level="O1")
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
+        >>> # Create a callable MindSpore graph with GE backend and some ge options on Ascend.
+        >>> @jit(backend="GE", ge_options={"global": {"ge.opSelectImplmode": "high_precision"}})
+        ... def tensor_add_by_trace(x, y):
+        ...     z = x + y
+        ...     return z
+        ...
+        >>> out = tensor_add_by_trace(x, y)
+        >>> print(out)
+        Tensor(shape=[1, 1, 3, 3], dtype=Float32, value=
+        [[[[ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00],
+           [ 2.00000000e+00, 2.00000000e+00, 2.00000000e+00]]]])
+        ...
     """
 
     capture_mode = Validator.check_string(capture_mode, ["ast", "bytecode", "trace"], "capture_mode", "jit")
@@ -1155,39 +1318,12 @@ def jit(
     jit_config = JitConfig(jit_level=jit_level, exc_mode=exc_mode, jit_syntax_level=jit_syntax_level,
                            infer_boost=infer_boost, backend=backend, options=options_str)
 
-
-
-
-
-
-
-        def staging_specialize(*args, **kwargs):
-            if os.getenv("MS_JIT") == '0':
-                return func(*args, **kwargs)
-
-            args, kwargs = _handle_func_args(func, *args, **kwargs)
-            process_obj = None
-            if args and not isinstance(args[0], PythonTensor) and hasattr(args[0], func.__name__):
-                process_obj = args[0]
-            # Handle auto mixed precision strategy.
-            if not hasattr(func, "amp_strategy"):
-                if isinstance(func, types.MethodType):
-                    setattr(func.__func__, "amp_strategy", get_curr_amp_strategy())
-                else:
-                    setattr(func, "amp_strategy", get_curr_amp_strategy())
-
-            ms_function_executor = _JitExecutor(func, hash_obj, None, process_obj, jit_config, dynamic)
-            out = ms_function_executor(*args, **kwargs)
-            return out
-
-        return staging_specialize
-
-    if capture_mode == "bytecode":
-        wrap_func = PIJitCaptureContext(jit_config)
-    elif capture_mode == "trace":
-        if function is not None:
-            return _jit_trace(function)
-        return _jit_trace
+    if capture_mode == "ast":
+        wrap_func = _jit_ast(hash_obj, dynamic, jit_config, '')
+    elif capture_mode == "bytecode":
+        wrap_func = PIJitCaptureContext(fullgraph=fullgraph, jit_config=jit_config)
+    else:
+        wrap_func = _jit_trace()
 
     if function is not None:
         return wrap_func(function)
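
The tail of `jit()` keeps the dual-use decorator convention: called as `jit(function=f)` (or as a bare `@jit`) it wraps immediately, while `@jit(...)` with only keyword arguments returns `wrap_func` for the decoration step. A minimal sketch of that convention, with graph capture elided:

```python
def jit_sketch(function=None, *, capture_mode="ast"):
    # Usable both bare (@jit_sketch) and parameterized (@jit_sketch(...)):
    # with `function` supplied, wrap now; otherwise return the wrapper.
    def wrap_func(func):
        def staged(*args, **kwargs):
            return func(*args, **kwargs)  # real graph capture elided
        return staged
    if function is not None:
        return wrap_func(function)
    return wrap_func

@jit_sketch                        # bare decoration: function passed directly
def double(x):
    return x * 2

@jit_sketch(capture_mode="trace")  # parameterized: wrap_func returned first
def inc(x):
    return x + 1

print(double(3), inc(3))  # 6 4
```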
@@ -1503,7 +1639,7 @@ class _PyNativeExecutor:
         """
         self._executor.end_graph(obj, output, *args, *(kwargs.values()))
 
-    def check_run(self, grad, obj, weights, grad_hash_id, *args):
+    def check_run(self, grad, obj, weights, grad_hash_id, *args, **kwargs):
         """
         Whether the forward graph need to construct.
 
@@ -1516,7 +1652,7 @@ class _PyNativeExecutor:
         Return:
             bool, specifies whether the forward graph needs to construct.
         """
-        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args)
+        return self._executor.check_run(grad, obj, weights, grad_hash_id, *args, **kwargs)
 
     def grad(self, obj, grad, weights, grad_position, *args):
         """
@@ -1834,13 +1970,6 @@ class _CellGraphExecutor:
         else:
             _set_dataset_mode_config('normal')
 
-    @staticmethod
-    def _use_vm_mode():
-        enable_ge = context.get_context("enable_ge")
-        enable_debug_runtime = context.get_context("enable_debug_runtime")
-        exe_mode = context.get_context("mode") == context.PYNATIVE_MODE
-        return not enable_ge or (enable_debug_runtime and exe_mode)
-
     def _build_data_graph(self, obj, phase):
         self._graph_executor.build_data_graph(obj.parameters_dict(), phase)
 
@@ -1872,7 +2001,7 @@ class _CellGraphExecutor:
         obj.__parse_method__ = 'construct'
         if not hasattr(obj, obj.__parse_method__):
             raise AttributeError(
-                'The class {}
+                'The class {} does not have method {}'.format(obj.__class__.__name__, obj.__parse_method__))
         key_id = str(id(obj)) + str(obj.create_time)
         args = get_auto_dynamic_shape_args(args, key_id)
 
@@ -1904,7 +2033,7 @@ class _CellGraphExecutor:
             _clear_auto_parallel_context(obj)
             return phase, False
 
-        full_function_name = obj.__class__.__name__ + '.' + str(obj.
+        full_function_name = obj.__class__.__name__ + '.' + str(obj.total_instance_count) + '.' + str(id(type(obj)))
         echo_function_name = obj.__class__.__name__
         _check_recompile(obj, args, kwargs, full_function_name, obj.create_time, echo_function_name)
 
@@ -1914,13 +2043,11 @@ class _CellGraphExecutor:
         self._set_compile_cache_dep_files(phase)
 
         self._graph_executor.set_weights_values(obj.parameters_dict())
-        if jit_config_dict:
-            self._graph_executor.set_jit_config(jit_config_dict)
-        else:
+        if not jit_config_dict:
             jit_config_dict = JitConfig().jit_config_dict
-            self._graph_executor.set_jit_config(jit_config_dict)
         gc.collect()
-        result = self._graph_executor.compile(
+        result = self._graph_executor.compile(
+            obj, args, kwargs, phase, jit_config_dict)
         obj.compile_cache.add(phase)
         if not result:
             raise RuntimeError("Executor compile failed.")
@@ -2121,5 +2248,3 @@ def flops_collection(phase='train'):
 
 _cell_graph_executor = _CellGraphExecutor()
 _pynative_executor = _PyNativeExecutor()
-
-__all__ = ['ms_memory_recycle', 'jit', 'jit_class', 'flops_collection']