mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +65 -84
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +58 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +178 -53
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +377 -203
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +117 -131
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +67 -55
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +70 -24
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +429 -23
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +140 -104
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +491 -623
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +4 -6
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
- mindspore/ops/auto_generate/gen_extend_func.py +27 -145
- mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
- mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +19 -102
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +77 -572
- mindspore/ops/function/nn_func.py +46 -94
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +4 -4
- mindspore/ops/functional_overload.py +594 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +5 -51
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +303 -177
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +8 -40
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +63 -15
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +16 -23
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +35 -14
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +13 -7
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +12 -12
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +10 -25
- mindspore/parallel/transform_safetensors.py +469 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +31 -32
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +17 -7
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +98 -21
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +133 -69
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/train/serialization.py
CHANGED
@@ -31,15 +31,14 @@ from multiprocessing import active_children
 import multiprocessing as mp
 from collections import OrderedDict
 from io import BytesIO
+from functools import partial

 import math
 import sys
 import time
-import google
 import numpy as np
-
-
-from safetensors import safe_open
+from safetensors.numpy import save_file
+import google

 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.mind_ir_pb2 import ModelProto as mindir_model
@@ -53,7 +52,6 @@ from mindspore.log import vlog_print
 from mindspore._checkparam import check_input_data, check_input_dataset
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
-from mindspore.common import np_dtype
 from mindspore.common.api import _cell_graph_executor as _executor
 from mindspore.common.api import _JitExecutor
 from mindspore.common.api import _get_parameter_layout
@@ -76,6 +74,7 @@ from mindspore.parallel.checkpoint_transform import restore_group_info_list as n
 from mindspore.parallel.checkpoint_transform import load_distributed_checkpoint as new_load_distributed_checkpoint
 from mindspore.parallel.checkpoint_transform import merge_sliced_parameter as new_merge_sliced_parameter
 from mindspore.parallel.checkpoint_transform import build_searched_strategy as new_build_searched_strategy
+from mindspore.parallel.transform_safetensors import _fast_safe_open
 from mindspore.train._utils import read_proto, get_parameter_redundancy, _progress_bar, _load_and_transform
 from mindspore._c_expression import load_mindir, _encrypt, _decrypt, _is_cipher_file, \
     split_mindir, split_dynamic_mindir
@@ -86,12 +85,9 @@ tensor_to_ms_type = {"Int8": mstype.int8, "UInt8": mstype.uint8, "Int16": mstype
                      "Float16": mstype.float16, "Float32": mstype.float32, "Float64": mstype.float64,
                      "Bool": mstype.bool_, "str": mstype.string, "BFloat16": mstype.bfloat16, "Int4": mstype.qint4x2}

-
-
-
-
-if hasattr(np_dtype, "bfloat16"):
-    tensor_to_np_type["BFloat16"] = np_dtype.bfloat16
+_tensor_to_np_type = {"Int8": np.int8, "UInt8": np.uint8, "Int16": np.int16, "UInt16": np.uint16,
+                      "Int32": np.int32, "UInt32": np.uint32, "Int64": np.int64, "UInt64": np.uint64,
+                      "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_, "str": "U"}

 np_type_convert = {"int32": np.int32, "float32": np.float32, "float16": np.float16, "float64": np.float64}

@@ -99,6 +95,8 @@ mindir_to_tensor_type = {1: mstype.float32, 2: mstype.uint8, 3: mstype.int8, 4:
                          5: mstype.int16, 6: mstype.int32, 7: mstype.int64, 10: mstype.float16,
                          11: mstype.float64, 12: mstype.uint32, 13: mstype.uint64}

+safetensors_to_mstype = {'Int4': mstype.qint4x2}
+
 _ckpt_mutex = RLock()

 # unit is KB
@@ -112,6 +110,21 @@ INT_64_MAX = 9223372036854775807
 cpu_cast = Cast().set_device("CPU")

 _ckpt_fs = FileSystem()
+_ckpt_fs_initialized = False
+
+
+def tensor_to_np_type(tensor_type_str):
+    """tensor to numpy type"""
+    if tensor_type_str == "BFloat16":
+        from mindspore.common import np_dtype
+        if not np_dtype.np_dtype_valid(True):
+            raise TypeError(
+                "The Numpy bfloat16 data type is not supported now, please ensure that the current "
+                "Numpy version is not less than the version when the mindspore is compiled, "
+                "and the major versions are same."
+            )
+        return np_dtype.bfloat16
+    return _tensor_to_np_type.get(tensor_type_str)


 def init_ckpt_file_system(fs: FileSystem):
@@ -121,8 +134,12 @@ def init_ckpt_file_system(fs: FileSystem):
     _register_basic_file_system(fs)


-
-
+def _ensure_ckpt_fs_initialized():
+    """Ensure checkpoint file system is initialized"""
+    global _ckpt_fs_initialized
+    if not _ckpt_fs_initialized:
+        init_ckpt_file_system(_ckpt_fs)
+        _ckpt_fs_initialized = True


 def _wait_async_process_save_ckpt():
@@ -272,10 +289,7 @@ def _update_param(param, new_param, strict_load):

     if param.data.dtype != new_param.data.dtype:
         if _type_convert(param, new_param, strict_load):
-
-                new_tensor = cpu_cast(new_param.data, param.data.dtype)
-            else:
-                new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
+            new_tensor = Tensor(new_param.data.asnumpy(), param.data.dtype)
             param.set_data(new_tensor, param.sliced)
             return

@@ -313,7 +327,7 @@ def _update_param(param, new_param, strict_load):
 def _type_convert(param, new_param, strict_load):
     """Whether to convert parameter's type during load checkpoint into network."""
     float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16)
-    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64)
+    int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.qint4x2)
     if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or
                             {param.data.dtype, new_param.data.dtype}.issubset(int_type)):
         logger.warning(f"The type of {new_param.name}:{new_param.data.dtype} in 'parameter_dict' is different from "
@@ -359,7 +373,7 @@ def _save_weight(checkpoint_dir, model_name, iteration, params):


 def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
-               format="ckpt"):
+               format="ckpt", remove_redundancy=None):
     """Execute the process of saving checkpoint into file."""
     try:
         with _ckpt_mutex:
@@ -383,9 +397,6 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_

             crc_num = 0
             for name, value in data_list.items():
-                if name == "random_op":
-                    _write_random_seed(name, value, f)
-                    continue
                 if value[0] == "mapparameter":
                     _write_mapparameter(name, value, f, map_param_inc)
                     continue
@@ -428,16 +439,19 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
         elif format == "safetensors":
             save_dict = {}
             crc_num = 0
+            meta_data = {"format": "ms"}
+            if remove_redundancy is not None and isinstance(remove_redundancy, bool):
+                meta_data["remove_redundancy"] = str(remove_redundancy)
             for name in sorted(data_list.keys()):
                 value = data_list[name]
                 if isinstance(value[2], np.ndarray):
+                    if value[1] == str(mstype.qint4x2):
+                        meta_data[name] = str(mstype.qint4x2)
                     save_dict[name] = value[2]
                 else:
-
-
-
-                    new_np_array = np_array.reshape(value[0])
-                    save_dict[name] = new_np_array
+                    if value[2].dtype == mstype.qint4x2:
+                        meta_data[name] = str(mstype.qint4x2)
+                    save_dict[name] = value[2].asnumpy()

                 if crc_check:
                     crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
@@ -445,10 +459,12 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                                              bytes(save_dict[name]), crc_num)
             safetensors_save_time_start = time.time()
             if crc_check:
-
-
+                meta_data.update({"crc_num": str(crc_num)})
+            if save_dict:
+                save_file(save_dict, tmp_name, metadata=meta_data)
             else:
                 save_file(save_dict, tmp_name)
+
             safetensors_save_time_end = time.time()
             cost_time = safetensors_save_time_end - safetensors_save_time_start
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")
@@ -457,25 +473,13 @@ def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_
                             f"simultaneously modified a file.")
         elif _ckpt_fs.backend != "mindio":
             os.rename(tmp_name, ckpt_file_name)
-
+            os.chmod(ckpt_file_name, stat.S_IRUSR)
     except BaseException as e:
         logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
                         "or the disk space is insufficient and so on.", ckpt_file_name)
         raise e


-def _write_random_seed(name, value, f):
-    """Write random op into protobuf file."""
-    checkpoint_list = Checkpoint()
-    param_value = checkpoint_list.value.add()
-    param_value.tag = name
-    param_tensor = param_value.tensor
-    param_tensor.dims.extend(0)
-    param_tensor.tensor_type = "random_op"
-    param_tensor.tensor_content = value
-    f.write(checkpoint_list.SerializeToString())
-
-
 def _write_parameter_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
     """Write parameter data into protobuf file."""
     data_size = value[2].nbytes / 1024
@@ -599,7 +603,7 @@ def _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format):
     return ckpt_file_name


-def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
+def _check_load_checkpoint_unsupported_param(format, dec_key, dec_mode):
     """check load checkpoint unsupported param"""
     if format != "safetensors":
         return
@@ -614,7 +618,7 @@ def _check_load_checkpoint_upsupported_param(format, dec_key, dec_mode):
                          f"be set to default value '{default_value}', but got '{current_value}'.")


-def
+def _check_save_checkpoint_unsupported_param(format, enc_key, enc_mode, map_param_inc=False, global_step_num=None):
     """check save checkpoint unsupported param"""
     if format != "safetensors":
         return
@@ -644,11 +648,11 @@ def _check_async_save(async_save):


 def _async_process_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False,
-                        crc_check=False, format="ckpt", cond=None):
+                        crc_check=False, format="ckpt", cond=None, remove_redundancy=None):
     """Check whether the process is pulled up successfully, execute the process of saving checkpoint into file."""
     with cond:
         cond.notify()
-    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
+    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format, remove_redundancy)


 def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
@@ -729,6 +733,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     start_save_time = time.time()
+    _ensure_ckpt_fs_initialized()
     ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
     integrated_save = Validator.check_bool(integrated_save)
     async_save = _check_async_save(async_save)
@@ -739,7 +744,9 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     map_param_inc = kwargs.get('incremental', False)
     logger.info("Execute the process of saving checkpoint files.")
     global_step_num = kwargs.get('global_step_num', None)
-
+    remove_redundancy = kwargs.get('remove_redundancy', None)
+    remove_redundancy = Validator.check_isinstance("remove_redundancy", remove_redundancy, (type(None), bool))
+    _check_save_checkpoint_unsupported_param(format, enc_key, enc_mode, map_param_inc, global_step_num)

     if append_dict and "__exception_save__" in append_dict:
         s1 = mindspore.hal.Stream()
@@ -768,16 +775,6 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
     data_list_np = OrderedDict()
     with _ckpt_mutex:
         for param in save_obj:
-            if param["name"] == "random_op":
-                if os.getenv("AITURBO") == "1":
-                    data_list_np["random_op"] = []
-                    data_list_np["random_op"].append(param["data"])
-                    if crc_check:
-                        bytes_value = bytes(data_list_np[key][0])
-                        data_list_np[key].append(binascii.crc32(bytes_value))
-                else:
-                    data_list["random_op"] = param["data"]
-                continue
             key = param["name"]
             data_list[key] = []
             data_list_np[key] = []
@@ -841,7 +838,7 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         while process_flag:
             process = ctx.Process(target=_async_process_save,
                                   args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
-                                        format, cond), daemon=True, name="asyn_save_ckpt")
+                                        format, cond, remove_redundancy), daemon=True, name="asyn_save_ckpt")
             process.start()
             with cond:
                 wait_flag = cond.wait(timeout=5)
@@ -854,11 +851,12 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
         data_copy = copy.deepcopy(data_list)
         _wait_async_thread_save_ckpt()
         thr = Thread(target=_exec_save,
-                     args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format
+                     args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format,
+                           remove_redundancy),
                      name="asyn_save_ckpt")
         thr.start()
     else:
-        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
+        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format, remove_redundancy)

     mstx.range_end(range_id)
     logger.info("Saving checkpoint process is finished.")
@@ -926,10 +924,13 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
     """Convert a dict of Parameter to param_list."""
     param_list = []
     for (key, value) in save_obj.items():
-        if isinstance(key, str)
+        if isinstance(key, str):
             if choice_func is not None and not choice_func(key):
                 continue
-
+            if isinstance(value, np.ndarray):
+                each_param = {"name": key, "data": Parameter(Tensor.from_numpy(value))}
+            if isinstance(value, (Parameter, str)) or _is_buffer_type(value):
+                each_param = {"name": key, "data": value}
             param_list.append(each_param)
         else:
             raise TypeError(f"For save_checkpoint, when save_obj is made up by dict, the key should be str and"
@@ -941,16 +942,12 @@ def _convert_dict_to_param_dict(save_obj, choice_func):
 def _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode):
     """Convert cell.parameters_and_names to OrderedDict."""
     param_dict = OrderedDict()
+    is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
     for _, param in save_obj.parameters_and_names():
-        if param.name.startswith("accu_grads") or param.name.endswith("expert_load"):
-            continue
-        not_sliced = not param.sliced
-        is_graph_mode = context.get_context('mode') == context.GRAPH_MODE
         # All parameters are initialized immediately under PyNative mode, skip this judgement.
-        judgment = not_sliced or param.has_init
         if param.param_info.is_pipeline_shared_param:
             continue
-        if
+        if is_parallel_mode and is_graph_mode and (not param.sliced or param.has_init):
             continue
         if choice_func is not None and not choice_func(param.name):
             continue
@@ -974,12 +971,6 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
     if not is_parallel_mode:
         save_obj.init_parameters_data()
     param_dict = _convert_cell_param_and_names_to_dict(save_obj, choice_func, is_parallel_mode)
-    if append_dict and "random_op" in append_dict:
-        phase = 'train' + '.' + str(save_obj.create_time) + '.' + str(id(save_obj)) + '.' + save_obj.arguments_key
-        if phase in save_obj.compile_cache and _executor.has_compiled(phase):
-            random_byte = _executor._graph_executor.get_random_status(phase)
-            param_list.append({"name": "random_op", "data": random_byte})
-        append_dict.pop("random_op")
     for (key, value) in param_dict.items():
         each_param = {"name": key}
         if isinstance(value, MapParameter):
@@ -1002,15 +993,14 @@ def _convert_cell_to_param_list(save_obj, integrated_save, append_dict, choice_f
             param_data.append(str(param_tensor.dtype))
             param_data.append(value.key)
         else:
-            param_data = value.data
             if append_dict and "__exception_save__" in append_dict:
                 param_data = Tensor(Tensor_.move_to(value, "CPU", False))
+            else:
+                param_data = Tensor(value.data)

         # in automatic model parallel scenario, some parameters were split to all the devices,
         # which should be combined before saving
         if key in parameter_layout_dict:
-            if not append_dict or "__exception_save__" not in append_dict:
-                param_data = Tensor(value.data)
             param_data = _get_merged_param_data(save_obj, parameter_layout_dict, key, param_data,
                                                 integrated_save)

@@ -1215,12 +1205,26 @@ def _check_param_type(param_config, key, target_type, requested):
     return None


+def _check_remove_redundancy(remove_redundancy, f):
+    """Check whether remove_redundancy is consistent with the safetensors file."""
+    if f.metadata() is not None and "remove_redundancy" in f.metadata().keys():
+        if f.metadata()["remove_redundancy"] == "True" and not remove_redundancy:
+            logger.warning("For 'load_checkpoint', the safetensors file is deduplicated, "
+                           "but remove_redundancy is set to False.")
+            return True
+        if f.metadata()["remove_redundancy"] == "False" and remove_redundancy:
+            logger.warning("For 'load_checkpoint', the safetensors file is non-deduplicated, "
+                           "but remove_redundancy is set to True.")
+            return False
+    return remove_redundancy
+
+
 def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix, choice_func, dec_key,
-                          dec_mode, crc_check, format):
+                          dec_mode, crc_check, format, remove_redundancy):
     """load parameter into parameter_dict"""
     ckpt_file_name = _check_ckpt_file_name(ckpt_file_name, format)
     if format == "safetensors":
-        with
+        with _fast_safe_open(ckpt_file_name, framework='np') as f:
             cal_crc_num = 0
             total_io_cost_time = 0
             for k in sorted(f.keys()):
@@ -1234,8 +1238,13 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                 io_end_time = time.time()
                 io_cost_time = io_end_time - io_start_time
                 total_io_cost_time += io_cost_time
-
-
+                if f.metadata() is not None and k in f.metadata().keys():
+                    sf_dtype = f.metadata()[k]
+                    ms_dtype = safetensors_to_mstype[sf_dtype]
+                    parameter_dict[k] = Parameter(Tensor(value, dtype=ms_dtype))
+                else:
+                    parameter_dict[k] = Parameter(Tensor.from_numpy(value))
+            remove_redundancy = _check_remove_redundancy(remove_redundancy, f)
             vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                        f"Load safetensors io cost time:{total_io_cost_time}.")
             if crc_check:
@@ -1248,7 +1257,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                 if cal_crc_num != crc_num:
                     raise ValueError("For 'load_checkpoint', the crc check has failed. "
                                      "Please check whether the ckpt file is damaged.")
-            return
+            return remove_redundancy
     checkpoint_list = _parse_ckpt_proto(ckpt_file_name, dec_key, dec_mode, crc_check)
     try:
         param_data_list = []
@@ -1261,9 +1270,6 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
             logger.warning("For load_checkpoint, this parameter `filter_prefix` will be deprecated, "
                            "please use `choice_func` instead.")
         for element_id, element in enumerate(checkpoint_list.value):
-            if element.tag == "random_op":
-                parameter_dict["random_op"] = element.tensor.tensor_content
-                continue
             if not _whether_load_param(specify_prefix, filter_prefix, element.tag):
                 continue
             if specify_prefix is None and filter_prefix is None and \
@@ -1278,11 +1284,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                 continue
             data = element.tensor.tensor_content
             data_type = element.tensor.tensor_type
-            np_type = tensor_to_np_type.get(data_type)
             ms_type = tensor_to_ms_type[data_type]
-            if data_type == 'str':
-                str_length = int(len(data) / 4)
-                np_type = np_type + str(str_length)
             param_data_list.append(data)
             if (element_id == len(checkpoint_list.value) - 1) or \
                     (element.tag != checkpoint_list.value[element_id + 1].tag):
@@ -1290,6 +1292,8 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                 param_data_list.clear()
                 dims = element.tensor.dims
                 if data_type == 'str':
+                    str_length = int(len(data) / 4)
+                    np_type = "U" + str(str_length)
                     str_value = np.frombuffer(new_data, np_type)
                     parameter_dict[element.tag] = str(str_value[0])
                 else:
@@ -1301,6 +1305,7 @@ def _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter
                     _offload_if_config(parameter)

         logger.info("Loading checkpoint files process is finished.")
+        return remove_redundancy

     except BaseException as e:
         logger.critical("Failed to load the checkpoint file '%s'.", ckpt_file_name)
@@ -1320,6 +1325,9 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
           And using either of those two args will override `choice_func` at the same time.
         - If none of the parameters are loaded from checkpoint file, it will throw ValueError.
         - When loading a checkpoint that has removed redundancy, the network should be compiled.
+        - When `net` is not None, it will verify whether the `remove_redundancy` parameter matches the
+          deduplication flag in the loaded safetensors file. If they are different, load the file according to
+          the deduplication flag in the file.

     Args:
         ckpt_file_name (str): Checkpoint file name.
@@ -1392,13 +1400,14 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
     """
     start_load_time = time.time()
     vlog_print("1", "ME", __file__, sys._getframe().f_lineno, "Begin load checkpoint.")
+    _ensure_ckpt_fs_initialized()
     specify_prefix = _check_prefix(specify_prefix)
     filter_prefix = _check_prefix(filter_prefix)
     dec_key = Validator.check_isinstance('dec_key', dec_key, (type(None), bytes))
     dec_mode = Validator.check_isinstance('dec_mode', dec_mode, str)
     crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
-
+    _check_load_checkpoint_unsupported_param(format, dec_key, dec_mode)
     logger.info("Execute the process of loading checkpoint files.")

     parameter_dict = {}
@@ -1424,8 +1433,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
                                 f"passed the CRC check and has been corrupted.")
             parameter_dict[key] = Parameter(Tensor(value[0]), name=key)
     else:
-        _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix,
-
+        remove_redundancy = _load_into_param_dict(ckpt_file_name, parameter_dict, specify_prefix, filter_prefix,
+                                                  choice_func, dec_key, dec_mode, crc_check, format, remove_redundancy)

     if not parameter_dict:
         raise ValueError(f"The loaded parameter dict is empty after filter or specify, please check whether "
@@ -1672,9 +1681,22 @@ def _check_load_param_into_net(net, parameter_dict):
         msg = ("For 'load_param_into_net', the argument 'parameter_dict' should be a dict, "
                "but got {}.".format(type(parameter_dict)))
         raise TypeError(msg)
-
-
-
+    for key, value in parameter_dict.items():
+        if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
+            logger.critical("Load parameters into net failed.")
+            msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
+                   "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
+            raise TypeError(msg)
+
+
+def _check_remove_redundancy_net(net):
+    """Check whether the network is compiled with the remove_redundancy feature."""
+    if get_group_size() == 1:
+        raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
+                        f"in parallel scenarios, but got stand_alone.")
+    if not net.compile_cache and not net.parameter_layout_dict:
+        raise ValueError("When loading a parameter dict that has removed redundancy, "
+                         "the network should be compiled.")


 def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundancy=False):
@@ -1721,18 +1743,14 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
         <https://mindspore.cn/tutorials/en/master/beginner/save_load.html#saving-and-loading-the-model-weight>`_
     """
     _check_load_param_into_net(net, parameter_dict)
-    for key, value in parameter_dict.items():
-        if not isinstance(key, str) or not isinstance(value, (Parameter, str, list)):
-            logger.critical("Load parameters into net failed.")
-            msg = ("For 'parameter_dict', the element in the argument 'parameter_dict' should be a "
-                   "'str' and 'Parameter' , but got {} and {}.".format(type(key), type(value)))
-            raise TypeError(msg)

     strict_load = Validator.check_bool(strict_load)
     remove_redundancy = Validator.check_isinstance('remove_redundancy', remove_redundancy, bool)
     logger.info("Execute the process of loading parameters into net.")
     param_not_load = []
+    param_loaded = set()
     ckpt_not_load = list(parameter_dict.keys())
+    is_parallel_mode = _is_auto_parallel_mode(net)
     for _, param in net.parameters_and_names():
         if param.param_info.is_pipeline_shared_param:
             continue
@@ -1748,22 +1766,23 @@ def load_param_into_net(net, parameter_dict, strict_load=False, remove_redundanc
             if hasattr(param, "init_param") and not param.init_param:
                 param.init_param = True
             ckpt_not_load.remove(param.name)
+            param_loaded.add(param.name)
         else:
+            if param.name.startswith("accu_grads"):
+                continue
+            if param.param_info.is_pipeline_shared_param:
+                continue
+            if is_parallel_mode and not param.sliced:
+                continue
             param_not_load.append(param.name)

     if param_not_load and not strict_load:
         _load_dismatch_prefix_params(net, parameter_dict, param_not_load, strict_load)

     if remove_redundancy:
-
-            raise TypeError(f"The deduplication feature for loading checkpoint can only be used "
-                            f"in parallel scenarios, but got stand_alone.")
-        if not net.compile_cache and not net.parameter_layout_dict:
-            raise ValueError("When loading a parameter dict that has removed redundancy, "
-                             "the network should be compiled.")
+        _check_remove_redundancy_net(net)
         param_layout = net.parameter_layout_dict
-        _single_parameter_broadcast(net, param_layout, param_not_load)
-        mindspore.hal.synchronize()
+        _single_parameter_broadcast(net, param_layout, param_not_load, param_loaded)

     logger.info("Loading parameters into net is finished.")
     if param_not_load:
@@ -1878,9 +1897,10 @@ def _save_graph(network, file_name):
         file_name (str): Graph file name into which the graph will be saved.
     """
     logger.info("Execute the process of saving graph.")
-
     file_name = os.path.realpath(file_name)
     graph_pb = network.get_func_graph_proto()
+    if os.path.isfile(file_name) and graph_pb:
+        os.remove(file_name)
     if graph_pb:
         with open(file_name, "wb") as f:
             os.chmod(file_name, stat.S_IRUSR | stat.S_IWUSR)
@@ -2193,6 +2213,11 @@ def _save_onnx(net, file_name, *inputs, **kwargs):
         file_name += ".onnx"
     if os.path.exists(file_name):
         os.chmod(file_name, stat.S_IWUSR)
+    else:
+        dir_path = os.path.dirname(file_name)
+        if not os.path.exists(dir_path):
+            os.makedirs(dir_path, mode=0o700, exist_ok=True)
+            os.chmod(dir_path, 0o700)
     with open(file_name, 'wb') as f:
         f.write(onnx_stream)
     os.chmod(file_name, stat.S_IRUSR)
@@ -2242,7 +2267,7 @@ def _get_data_file(is_encrypt, kwargs, data_file_name):
     if is_encrypt():
         place_holder_data = _encrypt(place_holder_data, len(place_holder_data), kwargs["enc_key"],
                                      len(kwargs["enc_key"]), kwargs["enc_mode"])
-    parameter_size =
+    parameter_size = offset / 1024
     try:
         f = open(data_file_name, "wb")
         f.write(place_holder_data)
@@ -2284,9 +2309,11 @@ def _split_save(net_dict, model, file_name, is_encrypt, **kwargs):
         external_local = os.path.join(file_prefix + "_variables", "data_" + str(index))
         data_file_name = os.path.join(dirname, external_local)
         f, parameter_size, offset = _get_data_file(is_encrypt, kwargs, data_file_name)
+
+        round = 0
+        names = []
+
         try:
-            round = 0
-            names = []
             for param_proto in model.graph.parameter:
                 name = param_proto.name[param_proto.name.find(":") + 1:]
                 names.append((name, param_proto))
@@ -2587,7 +2614,7 @@ def parse_print(print_file_name):
             dims = print_.tensor.dims
             data_type = print_.tensor.tensor_type
             data = print_.tensor.tensor_content
-            np_type = tensor_to_np_type
+            np_type = tensor_to_np_type(data_type)
             param_data = np.fromstring(data, np_type)
             ms_type = tensor_to_ms_type.get(data_type)
             if dims and dims != [0]:
@@ -2730,28 +2757,35 @@ def convert_model(mindir_file, convert_file, file_format):
     export(net, *net_input, file_name=convert_file, file_format=file_format)


-def
-    return _load_and_transform(path, name_map, mindspore.load_checkpoint,
+def _load_ckpt_to_new_name_map(path, name_map=None):
+    return _load_and_transform(path, name_map, mindspore.load_checkpoint, None)


-def
-
+def _load_sf_to_new_name_map(path, name_map=None):
+    load_func = partial(mindspore.load_checkpoint, format="safetensors")
+    return _load_and_transform(path, name_map, load_func, None)


 def _process_file(file_info):
     cur_ckpt_path, name_map, save_path, file = file_info
-
+    if name_map is not None:
+        param_dict = _load_ckpt_to_new_name_map(cur_ckpt_path, name_map)
+    else:
+        param_dict = mindspore.load_checkpoint(cur_ckpt_path)
     safetensors_filename = file.replace(".ckpt", ".safetensors")
     dst_file = os.path.join(save_path, safetensors_filename)
-
+    mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')


 def _process_file_safetensors(file_info):
     cur_safe_path, name_map, save_path, file = file_info
-
+    if name_map is not None:
+        param_dict = _load_sf_to_new_name_map(cur_safe_path, name_map)
+    else:
+        param_dict = mindspore.load_checkpoint(cur_safe_path, format="safetensors")
     ckpt_filename = file.replace(".safetensors", ".ckpt")
     dst_file = os.path.join(save_path, ckpt_filename)
-    mindspore.save_checkpoint(
+    mindspore.save_checkpoint(param_dict, dst_file)


 def _gather_safetensors_tasks(file_path, save_path, file_name_regex, name_map):
@@ -2862,10 +2896,14 @@ def ckpt_to_safetensors(file_path, save_path=None, name_map=None, file_name_rege
         if save_path and not os.path.exists(save_path):
             os.makedirs(save_path, exist_ok=True)

-
+        if name_map is not None:
+            param_dict = _load_ckpt_to_new_name_map(file_path, name_map)
+        else:
+            param_dict = mindspore.load_checkpoint(file_path)
+
         safetensors_filename = os.path.basename(file_path).replace(".ckpt", ".safetensors")
         dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), safetensors_filename)
-
+        mindspore.save_checkpoint(param_dict, dst_file, format='safetensors')


 def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_regex=None, processes_num=1):
@@ -2924,10 +2962,14 @@ def safetensors_to_ckpt(file_path, save_path=None, name_map=None, file_name_rege
         if save_path and not os.path.exists(save_path):
             os.makedirs(save_path, exist_ok=True)

-
+        if name_map is not None:
+            param_dict = _load_sf_to_new_name_map(file_path, name_map)
+        else:
+            param_dict = mindspore.load_checkpoint(file_path, format="safetensors")
+
         ckpt_filename = os.path.basename(file_path).replace(".safetensors", ".ckpt")
         dst_file = os.path.join(save_path if save_path else os.path.dirname(file_path), ckpt_filename)
-        mindspore.save_checkpoint(
+        mindspore.save_checkpoint(param_dict, dst_file)


 def restore_group_info_list(group_info_file_name):