mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +65 -84
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +58 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +178 -53
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +377 -203
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +117 -131
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +67 -55
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +70 -24
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +429 -23
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +140 -104
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +491 -623
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +4 -6
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
- mindspore/ops/auto_generate/gen_extend_func.py +27 -145
- mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
- mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +19 -102
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +77 -572
- mindspore/ops/function/nn_func.py +46 -94
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +4 -4
- mindspore/ops/functional_overload.py +594 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +5 -51
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +303 -177
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +8 -40
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +63 -15
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +16 -23
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +35 -14
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +13 -7
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +12 -12
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +10 -25
- mindspore/parallel/transform_safetensors.py +469 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +31 -32
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +17 -7
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +98 -21
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +133 -69
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
|
@@ -25,8 +25,9 @@ from mindspore.communication import get_rank, get_group_size
|
|
|
25
25
|
from mindspore import log as logger
|
|
26
26
|
from mindspore.train.serialization import _get_cur_rank_dp
|
|
27
27
|
from mindspore._c_expression import _repair_device, _stop_device, _tft_sem_post, _tft_sem_enable
|
|
28
|
-
from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm
|
|
28
|
+
from mindspore._c_expression import _rebuild_world_group, _rebuild_sub_group, _finalize_comm, _clean_rootinfo
|
|
29
29
|
from mindspore._c_expression import clean_tdt_channel
|
|
30
|
+
from mindspore._c_expression import _pre_launch_send_recv
|
|
30
31
|
from mindspore._c_expression import send_recv, reset_params
|
|
31
32
|
from mindspore._c_expression import CollectiveManager
|
|
32
33
|
from mindspore._c_expression import _get_uce_process_strategy, _get_uce_mem_info
|
|
@@ -35,6 +36,7 @@ from mindspore.ops.operations.manually_defined._inner import TensorReport
|
|
|
35
36
|
import mindspore
|
|
36
37
|
import mindspore.common.dtype as mstype
|
|
37
38
|
from mindspore.parallel._recovery_context import _set_recovery_context
|
|
39
|
+
from mindspore import runtime
|
|
38
40
|
|
|
39
41
|
|
|
40
42
|
def _get_ckpt_dir(step, ckpt_save_path, is_tmp_file):
|
|
@@ -80,7 +82,7 @@ def _save_checkpoint_on_failure(step, save_info, args, cb_ctx):
|
|
|
80
82
|
append_dict["loss_scale"] = outputs[2]
|
|
81
83
|
|
|
82
84
|
ckpt_file = f"ttp_rank_{str(cur_rank)}-{str(cur_epoch_num)}_{str(step_num_in_epoch)}.ckpt"
|
|
83
|
-
cur_ckpt_dir = _get_ckpt_dir(step, ckpt_save_path, True)
|
|
85
|
+
cur_ckpt_dir = os.path.join(_get_ckpt_dir(step, ckpt_save_path, True), "rank_" + str(cur_rank))
|
|
84
86
|
os.makedirs(cur_ckpt_dir, exist_ok=True)
|
|
85
87
|
cur_file = os.path.join(cur_ckpt_dir, ckpt_file)
|
|
86
88
|
save_checkpoint(cb_params.train_network, cur_file,
|
|
@@ -110,7 +112,7 @@ def _tft_exit_cb(ctx):
|
|
|
110
112
|
|
|
111
113
|
def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_ctx):
|
|
112
114
|
""" Callback used for TFT repair function."""
|
|
113
|
-
logger.warning("Enter _tft_repair_callback repair type: {
|
|
115
|
+
logger.warning(f"Enter _tft_repair_callback repair type: {repair_info['repair_type']}")
|
|
114
116
|
if (repair_info["repair_type"] in (cb_ctx.tft.RepairType.RT_UCE_HIGHLEVEL.value,
|
|
115
117
|
cb_ctx.tft.RepairType.RT_UCE_LOWLEVEL.value)):
|
|
116
118
|
logger.warning("Enter _tft_repair_callback uce REPARI_DEVICE device_id : {}".format(cb_ctx.device_id))
|
|
@@ -138,7 +140,7 @@ def _tft_repair_callback(step, need_rebuild, error_ranks, repair_info, args, cb_
|
|
|
138
140
|
|
|
139
141
|
def _tft_clean_callback(is_uce_error, args, ctx):
|
|
140
142
|
""" Callback used for TFT clean function."""
|
|
141
|
-
logger.warning("Enter _tft_clean_callback")
|
|
143
|
+
logger.warning(f"Enter _tft_clean_callback, device id:{ctx.device_id}")
|
|
142
144
|
ret = 0
|
|
143
145
|
if is_uce_error:
|
|
144
146
|
_get_uce_mem_info(ctx.device_id)
|
|
@@ -154,12 +156,16 @@ def _tft_clean_callback(is_uce_error, args, ctx):
|
|
|
154
156
|
logger.warning("Enter _tft_clean_callback resume_hccl_comm")
|
|
155
157
|
CollectiveManager.get_instance().resume_hccl_comm()
|
|
156
158
|
logger.warning("Finish _tft_clean_callback, ret: {}".format(ret))
|
|
159
|
+
if ctx.tft.tft_get_repair_type() == "recover":
|
|
160
|
+
logger.warning(f"Destroy hcom")
|
|
161
|
+
_finalize_comm()
|
|
162
|
+
logger.warning(f"Destroy hcom end")
|
|
157
163
|
return ret
|
|
158
164
|
|
|
159
165
|
|
|
160
166
|
def _tft_stop_callback(args, cb_ctx):
|
|
161
167
|
""" Callback used for TFT stop function."""
|
|
162
|
-
logger.warning("Enter _tft_stop_callback device_id: {
|
|
168
|
+
logger.warning(f"Enter _tft_stop_callback device_id: {cb_ctx.device_id}")
|
|
163
169
|
_stop_device(cb_ctx.device_id)
|
|
164
170
|
if (not cb_ctx.is_uce_rank) and (not cb_ctx._is_params_consistent()): # pylint: disable=W0212
|
|
165
171
|
raise RuntimeError("Can't stop device, because training parameters are left in inconsistent state!")
|
|
@@ -167,23 +173,25 @@ def _tft_stop_callback(args, cb_ctx):
|
|
|
167
173
|
if cb_ctx.tft.tft_get_repair_type() == "recover":
|
|
168
174
|
logger.warning(f"Reset limit step")
|
|
169
175
|
cb_ctx.tft.tft_reset_limit_step()
|
|
170
|
-
logger.
|
|
176
|
+
logger.warning("Finish _tft_stop_callback")
|
|
171
177
|
|
|
172
178
|
|
|
173
179
|
def _tft_rebuild_sub_groups(fault_ranks, args, ctx):
|
|
174
180
|
"""Callback used for TFT Rebuild Group function."""
|
|
175
|
-
logger.warning(f"Enter _tft_rebuild_sub_groups, device id:
|
|
176
|
-
_finalize_comm()
|
|
181
|
+
logger.warning(f"Enter _tft_rebuild_sub_groups, device id: {ctx.device_id}")
|
|
177
182
|
_rebuild_world_group()
|
|
178
183
|
_rebuild_sub_group()
|
|
179
184
|
_set_recovery_context(is_arf=True)
|
|
185
|
+
logger.warning(f"try to pre launch send recv before real launch")
|
|
186
|
+
_pre_launch_send_recv(context.get_context('device_id'))
|
|
187
|
+
logger.warning(f"Pre launch send recv before real launch end")
|
|
180
188
|
logger.warning("Enter _tft_rebuild_sub_groups ok ")
|
|
181
189
|
|
|
182
190
|
|
|
183
191
|
class TrainFaultTolerance(Callback):
|
|
184
192
|
"""
|
|
185
193
|
This callback is used to enable the TFT feature
|
|
186
|
-
`MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/
|
|
194
|
+
`MindIO TFT <https://www.hiascend.com/document/detail/zh/mindx-dl/600/clusterscheduling/ref/mindiottp/mindiotft001.html>`_
|
|
187
195
|
and will execute TFT operations during training process, such as TFT init, report and exception handle.
|
|
188
196
|
|
|
189
197
|
Note:
|
|
@@ -299,27 +307,69 @@ class TrainFaultTolerance(Callback):
|
|
|
299
307
|
|
|
300
308
|
def __init__(self, ckpt_save_path=None, **kwargs):
|
|
301
309
|
super(TrainFaultTolerance, self).__init__()
|
|
310
|
+
logger.info(f"MS_ENABLE_TFT: {os.getenv('MS_ENABLE_TFT', '')}")
|
|
311
|
+
if self._only_enable_tsp():
|
|
312
|
+
self.tft = _tft_handler.get_tft()
|
|
313
|
+
self._check_init()
|
|
314
|
+
self.tft.tft_register_stream_sync_handler(runtime.synchronize, self)
|
|
315
|
+
return
|
|
302
316
|
self.save_cb = kwargs.get("ckpt_save_fn", None)
|
|
303
317
|
self.ckpt_save_path = ckpt_save_path
|
|
304
318
|
if self.save_cb is None and self.ckpt_save_path is None:
|
|
305
319
|
raise ValueError("TrainFaultTolerance construct need to set ckpt_save_fn or ckpt_save_path!")
|
|
320
|
+
self.cb_params = None
|
|
321
|
+
self.initial_step = kwargs.get("initial_step", 0)
|
|
322
|
+
self.device_id = context.get_context("device_id")
|
|
323
|
+
self.cur_step_num = 0
|
|
324
|
+
self.cur_epoch_num = 0
|
|
325
|
+
self.clean_unique_id = False
|
|
326
|
+
# For TREError(Training Result Error) scene, parameter `ckpt_load_fn` must be provided to load checkpoint
|
|
327
|
+
# from file for resuming training, the `ckpt_load_fn` is a function, prototype of which is:
|
|
328
|
+
# `def load_checkpoint() -> tuple(dict, bool)`, the return value is a tuple containing 2 values,
|
|
329
|
+
# i.e. (param_dict, remove_redundancy)
|
|
330
|
+
self.ckpt_load_func = kwargs.get("ckpt_load_fn", None)
|
|
331
|
+
if self._only_enable_tre():
|
|
332
|
+
return
|
|
306
333
|
self.tft = _tft_handler.get_tft()
|
|
307
334
|
self._check_init()
|
|
335
|
+
if self._only_enable_tre_and_tsp():
|
|
336
|
+
self.tft.tft_register_stream_sync_handler(runtime.synchronize, self)
|
|
337
|
+
return
|
|
308
338
|
self.global_step = None
|
|
309
339
|
self.learning_rate = None
|
|
310
340
|
self.has_init_replica = False
|
|
311
341
|
self.is_uce_rank = False
|
|
312
|
-
|
|
313
|
-
self.initial_step = kwargs.get("initial_step", 0)
|
|
314
|
-
self.device_id = context.get_context("device_id")
|
|
342
|
+
|
|
315
343
|
self.assign = mindspore.ops.Assign()
|
|
316
344
|
self.g_one = Parameter(Tensor([1], dtype=mstype.int32))
|
|
317
345
|
self.s1 = mindspore.hal.Stream()
|
|
318
|
-
self.cur_step_num = 0
|
|
319
|
-
self.cur_epoch_num = 0
|
|
320
346
|
_tft_sem_enable()
|
|
321
347
|
self._tft_register()
|
|
322
348
|
|
|
349
|
+
def _only_enable_tre(self):
|
|
350
|
+
"""Check if only configured MS_ENABLE_TFT='{TRE:1}'"""
|
|
351
|
+
env_enable = os.getenv("MS_ENABLE_TFT", "")
|
|
352
|
+
non_tre_flags = ["TTP:1", "UCE:1", "ARF:1"]
|
|
353
|
+
if any(flag in env_enable for flag in non_tre_flags):
|
|
354
|
+
return False
|
|
355
|
+
return "TRE:1" in env_enable
|
|
356
|
+
|
|
357
|
+
def _only_enable_tsp(self):
|
|
358
|
+
"""Check if only configured MS_ENABLE_TFT='{TSP:1}'"""
|
|
359
|
+
env_enable = os.getenv("MS_ENABLE_TFT", "")
|
|
360
|
+
non_tsp_flags = ["TTP:1", "UCE:1", "ARF:1", "TRE:1"]
|
|
361
|
+
if any(flag in env_enable for flag in non_tsp_flags):
|
|
362
|
+
return False
|
|
363
|
+
return "TSP:1" in env_enable
|
|
364
|
+
|
|
365
|
+
def _only_enable_tre_and_tsp(self):
|
|
366
|
+
"""Check if only configured MS_ENABLE_TFT='{TRE:1, TSP:1}'"""
|
|
367
|
+
env_enable = os.getenv("MS_ENABLE_TFT", "")
|
|
368
|
+
other_flags = ["TTP:1", "UCE:1", "ARF:1"]
|
|
369
|
+
if any(flag in env_enable for flag in other_flags):
|
|
370
|
+
return False
|
|
371
|
+
return "TRE:1" in env_enable and "TSP:1" in env_enable
|
|
372
|
+
|
|
323
373
|
def _check_init(self):
|
|
324
374
|
"""Check if the mindio-ttp had inited"""
|
|
325
375
|
if self.tft is None:
|
|
@@ -330,11 +380,9 @@ class TrainFaultTolerance(Callback):
|
|
|
330
380
|
_tft_handler.init(config=None)
|
|
331
381
|
self.tft = _tft_handler.get_tft()
|
|
332
382
|
logger.warning(f"TFT handle init ok.")
|
|
333
|
-
mode = context.get_context("mode")
|
|
334
383
|
device_target = context.get_context("device_target")
|
|
335
|
-
if device_target != "Ascend"
|
|
336
|
-
raise ValueError(f"MindIO adataper only support on Ascend device
|
|
337
|
-
f"device:{device_target}, run mode: {mode}")
|
|
384
|
+
if device_target != "Ascend":
|
|
385
|
+
raise ValueError(f"MindIO adataper only support on Ascend device but got device {device_target}!")
|
|
338
386
|
|
|
339
387
|
def _is_params_consistent(self):
|
|
340
388
|
for key, param in self.cb_params.train_network.parameters_and_names():
|
|
@@ -411,6 +459,8 @@ class TrainFaultTolerance(Callback):
|
|
|
411
459
|
self.tft.tft_register_clean_handler(_tft_clean_callback, self)
|
|
412
460
|
self.tft.tft_register_repair_handler(_tft_repair_callback, self)
|
|
413
461
|
self.tft.tft_register_rebuild_group_handler(_tft_rebuild_sub_groups, self)
|
|
462
|
+
if "TSP:1" in os.getenv("MS_ENABLE_TFT", ""):
|
|
463
|
+
self.tft.tft_register_stream_sync_handler(runtime.synchronize, self)
|
|
414
464
|
|
|
415
465
|
def _reset_acc_grads(self):
|
|
416
466
|
accu_grad_params = map(lambda e: e[1],
|
|
@@ -420,6 +470,12 @@ class TrainFaultTolerance(Callback):
|
|
|
420
470
|
if reset_params(accu_grad_list) != 0:
|
|
421
471
|
raise ValueError("Call reset_params failed.")
|
|
422
472
|
|
|
473
|
+
def _clear_unique_id(self):
|
|
474
|
+
"""Clean unique id on first train step end"""
|
|
475
|
+
if not self.clean_unique_id and ("ARF:1" in os.getenv("MS_ENABLE_TFT", "")):
|
|
476
|
+
_clean_rootinfo()
|
|
477
|
+
self.clean_unique_id = True
|
|
478
|
+
|
|
423
479
|
def on_train_step_end(self, run_context):
|
|
424
480
|
"""
|
|
425
481
|
Report status to MindIO TFT after every step finished.
|
|
@@ -428,13 +484,21 @@ class TrainFaultTolerance(Callback):
|
|
|
428
484
|
run_context (RunContext): Context of the train running. Refer to
|
|
429
485
|
:class:`mindspore.train.RunContext` for detail.
|
|
430
486
|
"""
|
|
431
|
-
if self.
|
|
432
|
-
|
|
433
|
-
|
|
487
|
+
if self._only_enable_tre():
|
|
488
|
+
return
|
|
489
|
+
|
|
434
490
|
cb_params = run_context.original_args()
|
|
435
491
|
logger.info("START Set optimizer finish step status to TFT. step: {}".format(cb_params.cur_step_num))
|
|
436
492
|
self.cur_step_num = cb_params.cur_step_num
|
|
437
493
|
self.cur_epoch_num = cb_params.cur_epoch_num
|
|
494
|
+
if self._only_enable_tsp() or self._only_enable_tre_and_tsp():
|
|
495
|
+
logger.info("Go into tft_pause_train.")
|
|
496
|
+
self.tft.tft_pause_train(self.cur_step_num)
|
|
497
|
+
return
|
|
498
|
+
|
|
499
|
+
if self.has_init_replica is False:
|
|
500
|
+
self.has_init_replica = True
|
|
501
|
+
self._set_tft_optimizer_replica(run_context)
|
|
438
502
|
if cb_params.optimizer is not None:
|
|
439
503
|
self.global_step = cb_params.optimizer.global_step.clone()
|
|
440
504
|
self.assign(cb_params.optimizer.tft_g_one_flag, self.g_one)
|
|
@@ -444,7 +508,13 @@ class TrainFaultTolerance(Callback):
|
|
|
444
508
|
else:
|
|
445
509
|
raise ValueError("TFT feature need optimizer or network's optimizer!")
|
|
446
510
|
self.tft.tft_end_updating_os(cb_params.cur_step_num + self.initial_step)
|
|
511
|
+
if cb_params.is_arf:
|
|
512
|
+
self.clean_unique_id = False
|
|
513
|
+
self._clear_unique_id()
|
|
447
514
|
logger.info("END Set optimizer finish step status to TFT.")
|
|
515
|
+
if "TSP:1" in os.getenv("MS_ENABLE_TFT", ""):
|
|
516
|
+
logger.info("Go into tft_pause_train.")
|
|
517
|
+
self.tft.tft_pause_train(self.cur_step_num)
|
|
448
518
|
|
|
449
519
|
def on_train_begin(self, run_context):
|
|
450
520
|
"""
|
|
@@ -454,7 +524,12 @@ class TrainFaultTolerance(Callback):
|
|
|
454
524
|
run_context (RunContext): Context of the train running. Refer to
|
|
455
525
|
:class:`mindspore.train.RunContext` for detail.
|
|
456
526
|
"""
|
|
527
|
+
if self._only_enable_tsp():
|
|
528
|
+
return
|
|
457
529
|
cb_params = run_context.original_args()
|
|
530
|
+
if self._only_enable_tre():
|
|
531
|
+
self.cb_params = cb_params
|
|
532
|
+
return
|
|
458
533
|
sink_size = cb_params.get("sink_size", 0)
|
|
459
534
|
if sink_size > 1:
|
|
460
535
|
raise ValueError("TFT feature doesn't support sink_size > 1.")
|
|
@@ -470,4 +545,6 @@ class TrainFaultTolerance(Callback):
|
|
|
470
545
|
run_context (RunContext): Context of the train running. Refer to
|
|
471
546
|
:class:`mindspore.train.RunContext` for detail.
|
|
472
547
|
"""
|
|
548
|
+
if self._only_enable_tre() or self._only_enable_tsp() or self._only_enable_tre_and_tsp():
|
|
549
|
+
return
|
|
473
550
|
_tft_handler.unregister_tft()
|
mindspore/train/data_sink.py
CHANGED
|
@@ -16,9 +16,9 @@
|
|
|
16
16
|
from functools import wraps
|
|
17
17
|
import mindspore.ops as ops
|
|
18
18
|
from mindspore import context
|
|
19
|
-
from mindspore.common.dtype import
|
|
19
|
+
from mindspore.common.dtype import _pytype_to_dtype
|
|
20
20
|
from mindspore.common.api import jit
|
|
21
|
-
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
|
|
21
|
+
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, enable_data_broadcast
|
|
22
22
|
from mindspore.train.dataset_helper import _has_dynamic_shape, _check_inputs
|
|
23
23
|
import mindspore.dataset as ds
|
|
24
24
|
from mindspore._c_expression import _set_dataset_mode_config
|
|
@@ -41,6 +41,15 @@ def _init_sink_dataset(dataset, sink_size, input_signature, create_info):
|
|
|
41
41
|
is_info_queue = (create_info and sink_size == 1 and dataset_size != 1 and
|
|
42
42
|
input_signature is None and not dynamic_shape and
|
|
43
43
|
context.get_context('device_target') == 'Ascend')
|
|
44
|
+
|
|
45
|
+
# Don't enable dynamic shape(multi-subgraph) feature in pp/data_broadcast mode,
|
|
46
|
+
# otherwise get_data_info will stuck since some rank do not consume data.
|
|
47
|
+
use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
|
|
48
|
+
data_broadcast = enable_data_broadcast()
|
|
49
|
+
|
|
50
|
+
if use_pipeline_parallel or data_broadcast:
|
|
51
|
+
is_info_queue = False
|
|
52
|
+
|
|
44
53
|
transfer_dataset = _exec_datagraph(dataset, sink_size, create_data_info_queue=is_info_queue)
|
|
45
54
|
dataset.__transfer_dataset__ = transfer_dataset
|
|
46
55
|
|
|
@@ -52,7 +61,7 @@ def _init_sink_dataset(dataset, sink_size, input_signature, create_info):
|
|
|
52
61
|
_check_inputs(input_signature, dataset_shapes, dataset_types)
|
|
53
62
|
|
|
54
63
|
queue_name = transfer_dataset.queue_name
|
|
55
|
-
if _need_to_full()
|
|
64
|
+
if _need_to_full():
|
|
56
65
|
device_num = _get_device_num() // _get_pipeline_stages()
|
|
57
66
|
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
|
58
67
|
next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
|
@@ -85,12 +94,12 @@ def _get_next_op(dataset, ori_next_op, is_info_queue):
|
|
|
85
94
|
|
|
86
95
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
87
96
|
dataset_types, dataset_shapes = dataset.__transfer_dataset__.get_data_info()
|
|
88
|
-
dataset_types = [
|
|
97
|
+
dataset_types = [_pytype_to_dtype(x) for x in dataset_types] # pylint:disable=protected-access
|
|
89
98
|
key = str(dataset_types) + str(dataset_shapes)
|
|
90
99
|
if key in dataset.__sink_aux__.next_ops:
|
|
91
100
|
next_op = dataset.__sink_aux__.next_ops[key]
|
|
92
101
|
else:
|
|
93
|
-
if _need_to_full()
|
|
102
|
+
if _need_to_full():
|
|
94
103
|
device_num = _get_device_num() // _get_pipeline_stages()
|
|
95
104
|
dataset_shapes = _to_full_shapes(dataset_shapes, device_num)
|
|
96
105
|
next_op = ops.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
|
|
@@ -214,7 +223,7 @@ def data_sink(fn, dataset, sink_size=1, jit_config=None, input_signature=None):
|
|
|
214
223
|
loop = sink_size
|
|
215
224
|
create_info = True
|
|
216
225
|
if jit_config is None:
|
|
217
|
-
create_info =
|
|
226
|
+
create_info = loop == 1
|
|
218
227
|
loop = 1
|
|
219
228
|
ori_next_op, is_info_queue = _init_sink_dataset(dataset, loop, input_signature, create_info)
|
|
220
229
|
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
1
|
+
# Copyright 2020-2025 Huawei Technologies Co., Ltd
|
|
2
2
|
#
|
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
4
|
# you may not use this file except in compliance with the License.
|
|
@@ -20,8 +20,8 @@ import copy
|
|
|
20
20
|
|
|
21
21
|
from mindspore import _checkparam as Validator
|
|
22
22
|
from mindspore import log as logger
|
|
23
|
-
from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
|
|
24
|
-
from mindspore.common.dtype import
|
|
23
|
+
from mindspore.common.dynamic_shape._auto_dynamic import is_auto_dynamic, convert_new_shapes
|
|
24
|
+
from mindspore.common.dtype import _pytype_to_dtype
|
|
25
25
|
from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
|
|
26
26
|
from mindspore.common._utils import is_shape_unknown
|
|
27
27
|
from mindspore.dataset.core import config as dataset_config
|
|
@@ -34,7 +34,7 @@ from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_t
|
|
|
34
34
|
_origin_shapes, _dynamic_shape_for_dataset
|
|
35
35
|
from mindspore.parallel._ps_context import _is_role_sched
|
|
36
36
|
from mindspore.ops import operations as P
|
|
37
|
-
from mindspore.common.auto_dynamic_shape import _auto_dynamic_shape
|
|
37
|
+
from mindspore.common.dynamic_shape.auto_dynamic_shape import _auto_dynamic_shape
|
|
38
38
|
|
|
39
39
|
|
|
40
40
|
def _send_data(dataset, epoch_num):
|
|
@@ -275,7 +275,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
275
275
|
# Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
|
|
276
276
|
if _need_to_full():
|
|
277
277
|
dataset_shapes = _to_full_shapes(dataset_shapes, _get_device_num() // _get_pipeline_stages())
|
|
278
|
-
dataset_types = [
|
|
278
|
+
dataset_types = [_pytype_to_dtype(x) for x in dataset_types] # pylint:disable=protected-access
|
|
279
279
|
if not is_dynamic:
|
|
280
280
|
dataset_shapes = _auto_dynamic_shape.auto_dynamic_generate_compile_args(dataset_shapes, True)
|
|
281
281
|
key = str(dataset_types) + str(dataset_shapes)
|
|
@@ -564,6 +564,15 @@ class _DatasetIter:
|
|
|
564
564
|
self.sink_size = dataset.__loop_size__
|
|
565
565
|
create_data_info_queue = (
|
|
566
566
|
sink_size == 1 and self.sink_count == 1 and dataset.get_dataset_size() != 1)
|
|
567
|
+
|
|
568
|
+
# Don't enable dynamic shape(multi-subgraph) feature in pp/data_broadcast mode,
|
|
569
|
+
# otherwise get_data_info will stuck since some rank do not consume data.
|
|
570
|
+
use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
|
|
571
|
+
data_broadcast = enable_data_broadcast()
|
|
572
|
+
|
|
573
|
+
if use_pipeline_parallel or data_broadcast:
|
|
574
|
+
create_data_info_queue = False
|
|
575
|
+
|
|
567
576
|
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size,
|
|
568
577
|
create_data_info_queue=create_data_info_queue)
|
|
569
578
|
|