mindspore 2.6.0rc1__cp311-cp311-win_amd64.whl → 2.7.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +2 -2
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +42 -11
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/{_deprecated → _extends/optimize}/__init__.py +9 -3
- mindspore/_extends/optimize/cell_utils.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +1109 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/compile_config.py +44 -22
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -2
- mindspore/_extends/parse/parser.py +65 -84
- mindspore/_extends/parse/resources.py +39 -0
- mindspore/_extends/parse/standard_method.py +58 -14
- mindspore/_extends/parse/trope.py +8 -1
- mindspore/_extends/pijit/__init__.py +1 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +2 -5
- mindspore/amp.py +4 -22
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/adasum.py +1 -1
- mindspore/boost/boost_cell_wrapper.py +4 -4
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +43 -12
- mindspore/common/_grad_function.py +2 -1
- mindspore/common/_pijit_context.py +28 -7
- mindspore/common/_stub_tensor.py +1 -209
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +178 -53
- mindspore/common/_utils.py +9 -1
- mindspore/common/api.py +377 -203
- mindspore/common/dtype.py +108 -57
- mindspore/common/dump.py +11 -16
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/{auto_dynamic_shape.py → dynamic_shape/auto_dynamic_shape.py} +17 -23
- mindspore/common/dynamic_shape/enable_dynamic.py +197 -0
- mindspore/common/file_system.py +59 -9
- mindspore/common/generator.py +5 -3
- mindspore/common/hook_handle.py +33 -5
- mindspore/common/jit_config.py +1 -1
- mindspore/common/jit_trace.py +84 -105
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +27 -29
- mindspore/common/recompute.py +5 -7
- mindspore/common/sparse_tensor.py +0 -3
- mindspore/common/symbol.py +0 -1
- mindspore/common/tensor.py +117 -131
- mindspore/communication/_comm_helper.py +46 -4
- mindspore/communication/management.py +79 -7
- mindspore/context.py +67 -55
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +38 -4
- mindspore/dataset/engine/datasets.py +350 -322
- mindspore/dataset/engine/datasets_user_defined.py +70 -24
- mindspore/dataset/engine/iterators.py +2 -2
- mindspore/dataset/engine/obs/config_loader.py +2 -2
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +8 -0
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/py_transforms.py +7 -3
- mindspore/dataset/transforms/transforms.py +10 -6
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +17 -5
- mindspore/dataset/vision/utils.py +632 -21
- mindspore/dataset/vision/validators.py +1 -0
- mindspore/device_context/ascend/device.py +1 -1
- mindspore/device_context/ascend/op_tuning.py +35 -1
- mindspore/device_context/gpu/__init__.py +2 -2
- mindspore/device_context/gpu/device.py +1 -1
- mindspore/device_context/gpu/op_precision.py +4 -2
- mindspore/device_context/gpu/op_tuning.py +6 -3
- mindspore/device_manager.py +16 -9
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +3 -4
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/optim/adadelta.py +13 -20
- mindspore/experimental/optim/adagrad.py +15 -22
- mindspore/experimental/optim/adam.py +17 -24
- mindspore/experimental/optim/adamax.py +14 -22
- mindspore/experimental/optim/adamw.py +28 -34
- mindspore/experimental/optim/asgd.py +15 -25
- mindspore/experimental/optim/lr_scheduler.py +27 -45
- mindspore/experimental/optim/nadam.py +14 -24
- mindspore/experimental/optim/optimizer.py +13 -23
- mindspore/experimental/optim/radam.py +18 -24
- mindspore/experimental/optim/rmsprop.py +14 -25
- mindspore/experimental/optim/rprop.py +15 -26
- mindspore/experimental/optim/sgd.py +9 -19
- mindspore/hal/__init__.py +4 -4
- mindspore/hal/contiguous_tensors_handle.py +2 -2
- mindspore/hal/memory.py +27 -7
- mindspore/include/api/cell.h +65 -5
- mindspore/include/api/cfg.h +24 -7
- mindspore/include/api/context.h +1 -0
- mindspore/include/api/delegate.h +10 -2
- mindspore/include/api/dual_abi_helper.h +100 -19
- mindspore/include/api/graph.h +14 -1
- mindspore/include/api/kernel.h +16 -3
- mindspore/include/api/kernel_api.h +9 -1
- mindspore/include/api/metrics/accuracy.h +9 -0
- mindspore/include/api/model.h +8 -1
- mindspore/include/api/model_group.h +4 -0
- mindspore/include/api/model_parallel_runner.h +2 -0
- mindspore/include/api/status.h +48 -10
- mindspore/include/api/types.h +8 -3
- mindspore/include/c_api/model_c.h +0 -58
- mindspore/include/c_api/tensor_c.h +0 -26
- mindspore/include/dataset/constants.h +9 -0
- mindspore/include/dataset/vision_ascend.h +1 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +61 -11
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -0
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +6 -46
- mindspore/mint/distributed/__init__.py +5 -0
- mindspore/mint/distributed/distributed.py +429 -23
- mindspore/mint/nn/__init__.py +1 -1
- mindspore/mint/nn/functional.py +53 -6
- mindspore/mint/nn/layer/_functions.py +163 -294
- mindspore/mint/nn/layer/activation.py +8 -6
- mindspore/mint/nn/layer/conv.py +140 -104
- mindspore/mint/nn/layer/normalization.py +11 -25
- mindspore/mint/optim/adam.py +19 -18
- mindspore/mint/optim/adamw.py +14 -8
- mindspore/mint/optim/sgd.py +5 -5
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/cell.py +491 -623
- mindspore/nn/grad/cell_grad.py +11 -12
- mindspore/nn/layer/activation.py +36 -36
- mindspore/nn/layer/basic.py +74 -77
- mindspore/nn/layer/channel_shuffle.py +4 -4
- mindspore/nn/layer/combined.py +4 -2
- mindspore/nn/layer/conv.py +117 -110
- mindspore/nn/layer/dense.py +9 -7
- mindspore/nn/layer/embedding.py +50 -52
- mindspore/nn/layer/image.py +38 -40
- mindspore/nn/layer/math.py +111 -112
- mindspore/nn/layer/normalization.py +56 -44
- mindspore/nn/layer/pooling.py +58 -63
- mindspore/nn/layer/rnn_cells.py +33 -33
- mindspore/nn/layer/rnns.py +56 -56
- mindspore/nn/layer/thor_layer.py +74 -73
- mindspore/nn/layer/transformer.py +11 -1
- mindspore/nn/learning_rate_schedule.py +20 -20
- mindspore/nn/loss/loss.py +79 -81
- mindspore/nn/optim/adam.py +4 -6
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -0
- mindspore/nn/optim/lamb.py +1 -3
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +2 -3
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/probability/distribution/exponential.py +2 -1
- mindspore/nn/probability/distribution/poisson.py +2 -1
- mindspore/nn/sparse/sparse.py +3 -3
- mindspore/nn/wrap/cell_wrapper.py +73 -42
- mindspore/nn/wrap/grad_reducer.py +37 -52
- mindspore/nn/wrap/loss_scale.py +72 -74
- mindspore/numpy/array_creations.py +7 -7
- mindspore/numpy/fft.py +1 -1
- mindspore/numpy/math_ops.py +5 -5
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +51 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +0 -9
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/{experimental/es/__init__.py → ops/_op_impl/cpu/joinedstr_op.py} +12 -6
- mindspore/ops/_vmap/vmap_array_ops.py +31 -13
- mindspore/ops/_vmap/vmap_nn_ops.py +8 -16
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +54 -13
- mindspore/ops/auto_generate/gen_extend_func.py +27 -145
- mindspore/ops/auto_generate/gen_ops_def.py +1027 -347
- mindspore/ops/auto_generate/gen_ops_prim.py +2341 -1117
- mindspore/ops/auto_generate/pyboost_inner_prim.py +31 -1
- mindspore/ops/composite/__init__.py +10 -0
- mindspore/ops/composite/base.py +9 -5
- mindspore/ops/composite/multitype_ops/__init__.py +12 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +133 -109
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/composite/multitype_ops/add_impl.py +70 -2
- mindspore/ops/composite/multitype_ops/div_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +11 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +5 -3
- mindspore/ops/composite/multitype_ops/mul_impl.py +49 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +57 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +34 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +14 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/_add_attr_func.py +11 -6
- mindspore/ops/function/array_func.py +19 -102
- mindspore/ops/function/debug_func.py +8 -5
- mindspore/ops/function/grad/grad_func.py +5 -13
- mindspore/ops/function/math_func.py +77 -572
- mindspore/ops/function/nn_func.py +46 -94
- mindspore/ops/function/other_func.py +4 -1
- mindspore/ops/function/random_func.py +44 -5
- mindspore/ops/function/vmap_func.py +2 -1
- mindspore/ops/functional.py +4 -4
- mindspore/ops/functional_overload.py +594 -18
- mindspore/ops/op_info_register.py +21 -0
- mindspore/ops/operations/__init__.py +16 -11
- mindspore/ops/operations/_custom_ops_utils.py +689 -34
- mindspore/ops/operations/_inner_ops.py +14 -18
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/array_ops.py +5 -51
- mindspore/ops/operations/comm_ops.py +186 -41
- mindspore/ops/operations/custom_ops.py +303 -177
- mindspore/ops/operations/debug_ops.py +59 -4
- mindspore/ops/operations/image_ops.py +13 -13
- mindspore/ops/operations/manually_defined/ops_def.py +27 -28
- mindspore/ops/operations/math_ops.py +8 -9
- mindspore/ops/operations/nn_ops.py +8 -40
- mindspore/ops/primitive.py +9 -20
- mindspore/ops/tensor_method.py +63 -15
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +1 -1
- mindspore/ops_generate/api/functional_map_cpp_generator.py +10 -9
- mindspore/ops_generate/api/functions_cc_generator.py +58 -10
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +1 -1
- mindspore/ops_generate/common/base_generator.py +14 -0
- mindspore/ops_generate/common/gen_constants.py +8 -3
- mindspore/ops_generate/common/gen_utils.py +0 -19
- mindspore/ops_generate/common/op_proto.py +11 -4
- mindspore/ops_generate/common/template.py +88 -11
- mindspore/ops_generate/gen_ops.py +1 -1
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +4 -4
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_name_h_generator.py +0 -3
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +0 -4
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -2
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +49 -8
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +2 -2
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +31 -16
- mindspore/ops_generate/pyboost/op_template_parser.py +98 -72
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +70 -273
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +14 -6
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +316 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +5 -3
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_internal_functions_cpp_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_functions_h_generator.py +76 -0
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +125 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +4 -3
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +348 -61
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_utils.py +118 -9
- mindspore/ops_generate/tensor_py_cc_generator.py +1 -24
- mindspore/parallel/_auto_parallel_context.py +16 -23
- mindspore/parallel/_cell_wrapper.py +113 -45
- mindspore/parallel/_parallel_serialization.py +4 -3
- mindspore/parallel/_ps_context.py +4 -6
- mindspore/parallel/_tensor.py +167 -12
- mindspore/parallel/_transformer/moe.py +1 -1
- mindspore/parallel/_transformer/transformer.py +17 -12
- mindspore/parallel/_utils.py +5 -11
- mindspore/parallel/auto_parallel.py +35 -14
- mindspore/parallel/checkpoint_convert.py +3 -3
- mindspore/parallel/checkpoint_transform.py +13 -7
- mindspore/parallel/cluster/process_entity/_api.py +88 -49
- mindspore/parallel/cluster/process_entity/_utils.py +95 -7
- mindspore/parallel/cluster/run.py +48 -7
- mindspore/parallel/function/__init__.py +8 -1
- mindspore/parallel/function/reshard_func.py +12 -12
- mindspore/parallel/nn/__init__.py +15 -2
- mindspore/parallel/nn/parallel_cell_wrapper.py +50 -14
- mindspore/parallel/nn/parallel_grad_reducer.py +7 -14
- mindspore/parallel/shard.py +10 -25
- mindspore/parallel/transform_safetensors.py +469 -174
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +7 -7
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +3 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +12 -6
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +4 -4
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +3 -3
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +4 -1
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +2 -1
- mindspore/profiler/analysis/task_manager.py +1 -1
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +5 -1
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +2 -1
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +10 -9
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +43 -23
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +3 -2
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +9 -5
- mindspore/profiler/analysis/viewer/ms_operator_details_viewer.py +132 -0
- mindspore/profiler/common/constant.py +16 -0
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +9 -0
- mindspore/profiler/common/profiler_context.py +50 -29
- mindspore/profiler/common/profiler_info.py +0 -16
- mindspore/profiler/common/profiler_meta_data.py +1 -0
- mindspore/profiler/common/profiler_op_analyse.py +239 -0
- mindspore/profiler/common/profiler_output_path.py +23 -8
- mindspore/profiler/common/profiler_parameters.py +128 -35
- mindspore/profiler/dynamic_profile/__init__.py +0 -0
- mindspore/profiler/dynamic_profile/dynamic_monitor_proxy.py +39 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_config_context.py +666 -0
- mindspore/profiler/dynamic_profile/dynamic_profiler_utils.py +62 -0
- mindspore/profiler/dynamic_profiler.py +374 -338
- mindspore/profiler/envprofiler.py +42 -12
- mindspore/profiler/experimental_config.py +112 -7
- mindspore/profiler/mstx.py +33 -12
- mindspore/profiler/platform/__init__.py +2 -3
- mindspore/profiler/platform/cpu_profiler.py +10 -4
- mindspore/profiler/platform/npu_profiler.py +30 -20
- mindspore/profiler/profiler.py +218 -154
- mindspore/profiler/profiler_action_controller.py +65 -77
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/profiler/schedule.py +10 -4
- mindspore/rewrite/common/config.py +1 -0
- mindspore/rewrite/common/namer.py +1 -0
- mindspore/rewrite/common/namespace.py +1 -0
- mindspore/rewrite/node/node.py +31 -11
- mindspore/rewrite/parsers/assign_parser.py +1 -1
- mindspore/rewrite/symbol_tree/symbol_tree.py +2 -2
- mindspore/run_check/_check_version.py +7 -10
- mindspore/runtime/__init__.py +8 -6
- mindspore/runtime/event.py +10 -4
- mindspore/runtime/executor.py +87 -45
- mindspore/runtime/memory.py +31 -32
- mindspore/runtime/thread_bind_core.py +299 -165
- mindspore/safeguard/rewrite_obfuscation.py +12 -13
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +17 -7
- mindspore/train/amp.py +43 -23
- mindspore/train/callback/__init__.py +5 -5
- mindspore/train/callback/_callback.py +2 -1
- mindspore/train/callback/_checkpoint.py +4 -14
- mindspore/train/callback/_flops_collector.py +11 -7
- mindspore/train/callback/_landscape.py +0 -1
- mindspore/train/callback/_train_fault_tolerance.py +98 -21
- mindspore/train/data_sink.py +15 -6
- mindspore/train/dataset_helper.py +14 -5
- mindspore/train/model.py +133 -69
- mindspore/train/serialization.py +168 -126
- mindspore/train/summary/summary_record.py +13 -2
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +3 -2
- mindspore/utils/dryrun.py +0 -6
- mindspore/utils/runtime_execution_order_check.py +163 -77
- mindspore/utils/sdc_detect.py +68 -0
- mindspore/utils/utils.py +14 -17
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/METADATA +5 -4
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/RECORD +403 -442
- mindspore/_deprecated/jit.py +0 -198
- mindspore/_extends/remote/kernel_build_server_ascend.py +0 -75
- mindspore/communication/_hccl_management.py +0 -297
- mindspore/experimental/es/embedding_service.py +0 -891
- mindspore/experimental/es/embedding_service_layer.py +0 -581
- mindspore/profiler/common/validator/__init__.py +0 -14
- mindspore/profiler/common/validator/validate_path.py +0 -84
- mindspore/profiler/parser/__init__.py +0 -14
- mindspore/profiler/parser/aicpu_data_parser.py +0 -272
- mindspore/profiler/parser/ascend_analysis/__init__.py +0 -14
- mindspore/profiler/parser/ascend_analysis/constant.py +0 -71
- mindspore/profiler/parser/ascend_analysis/file_manager.py +0 -180
- mindspore/profiler/parser/ascend_analysis/function_event.py +0 -185
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +0 -136
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +0 -131
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +0 -104
- mindspore/profiler/parser/ascend_analysis/path_manager.py +0 -313
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +0 -123
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +0 -86
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +0 -75
- mindspore/profiler/parser/ascend_cluster_generator.py +0 -116
- mindspore/profiler/parser/ascend_communicate_generator.py +0 -314
- mindspore/profiler/parser/ascend_flops_generator.py +0 -116
- mindspore/profiler/parser/ascend_fpbp_generator.py +0 -82
- mindspore/profiler/parser/ascend_hccl_generator.py +0 -271
- mindspore/profiler/parser/ascend_integrate_generator.py +0 -42
- mindspore/profiler/parser/ascend_memory_generator.py +0 -185
- mindspore/profiler/parser/ascend_msprof_exporter.py +0 -282
- mindspore/profiler/parser/ascend_msprof_generator.py +0 -187
- mindspore/profiler/parser/ascend_op_generator.py +0 -334
- mindspore/profiler/parser/ascend_steptrace_generator.py +0 -94
- mindspore/profiler/parser/ascend_timeline_generator.py +0 -545
- mindspore/profiler/parser/base_timeline_generator.py +0 -483
- mindspore/profiler/parser/container.py +0 -229
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +0 -697
- mindspore/profiler/parser/flops_parser.py +0 -531
- mindspore/profiler/parser/framework_enum.py +0 -111
- mindspore/profiler/parser/framework_parser.py +0 -464
- mindspore/profiler/parser/framework_struct.py +0 -61
- mindspore/profiler/parser/gpu_analysis/__init__.py +0 -14
- mindspore/profiler/parser/gpu_analysis/function_event.py +0 -44
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +0 -89
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +0 -72
- mindspore/profiler/parser/hccl_parser.py +0 -573
- mindspore/profiler/parser/hwts_log_parser.py +0 -122
- mindspore/profiler/parser/integrator.py +0 -526
- mindspore/profiler/parser/memory_usage_parser.py +0 -277
- mindspore/profiler/parser/minddata_analyzer.py +0 -800
- mindspore/profiler/parser/minddata_parser.py +0 -186
- mindspore/profiler/parser/minddata_pipeline_parser.py +0 -299
- mindspore/profiler/parser/op_intermediate_parser.py +0 -149
- mindspore/profiler/parser/optime_parser.py +0 -250
- mindspore/profiler/parser/profiler_info.py +0 -213
- mindspore/profiler/parser/step_trace_parser.py +0 -666
- mindspore/utils/hooks.py +0 -81
- /mindspore/common/{_auto_dynamic.py → dynamic_shape/_auto_dynamic.py} +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/WHEEL +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.6.0rc1.dist-info → mindspore-2.7.0.dist-info}/top_level.txt +0 -0
mindspore/train/model.py
CHANGED
|
@@ -57,8 +57,10 @@ from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_trai
|
|
|
57
57
|
from mindspore.train import amp
|
|
58
58
|
from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
|
|
59
59
|
from mindspore._c_expression import _get_optimzer_timestamps
|
|
60
|
+
from mindspore._c_expression import clean_tdt_channel, _clean_rootinfo
|
|
60
61
|
|
|
61
62
|
from mindspore.parallel._utils import _init_auto_parallel_context, _clear_auto_parallel_context
|
|
63
|
+
from .serialization import load_param_into_net
|
|
62
64
|
|
|
63
65
|
def _transfer_tensor_to_tuple(inputs):
|
|
64
66
|
"""
|
|
@@ -130,7 +132,8 @@ def _handle_exception_info(obj, uce_env, tft, e):
|
|
|
130
132
|
if not uce_env:
|
|
131
133
|
logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
|
|
132
134
|
exc_info=True)
|
|
133
|
-
tft
|
|
135
|
+
if tft:
|
|
136
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
134
137
|
raise e
|
|
135
138
|
e_str = str(e)
|
|
136
139
|
logger.warning("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
@@ -151,6 +154,9 @@ def _handle_exception_info(obj, uce_env, tft, e):
|
|
|
151
154
|
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
152
155
|
raise e
|
|
153
156
|
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
157
|
+
elif "HCCEError" in e_str:
|
|
158
|
+
logger.warning("uce wrapper caught HCCEError")
|
|
159
|
+
tft.tft_report_error(tft.ReportState.RS_HCCL_FAILED.value)
|
|
154
160
|
elif "ForceStopError" in e_str:
|
|
155
161
|
logger.warning("uce wrapper caught RuntimeError ForceStopError")
|
|
156
162
|
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
@@ -165,6 +171,69 @@ def _handle_exception_info(obj, uce_env, tft, e):
|
|
|
165
171
|
raise e
|
|
166
172
|
|
|
167
173
|
|
|
174
|
+
def _handle_training_result_error(model, tft_obj):
    """
    Recover from a training result error so that training can be resumed.

    The TDT channel is cleaned first, then the checkpoint returned by
    ``tft_obj.ckpt_load_func`` is loaded into ``tft_obj.cb_params.train_network``
    and the resume step is stored on ``model._initial_step``.

    Args:
        model: Object whose ``_initial_step`` attribute receives the resume step.
        tft_obj: Object providing ``ckpt_load_func`` and ``cb_params.train_network``.

    Returns:
        tuple, the ``'epoch_num'`` and ``'step_num'`` entries of the loaded
        parameter dict, in that order.
    """
    # Both attributes are read before anything is logged or cleaned.
    load_checkpoint = tft_obj.ckpt_load_func
    network = tft_obj.cb_params.train_network
    logger.warning("Process training result error start.")

    # Step 1: clean the tdt channel.
    logger.warning("Clean tdt channel.")
    clean_tdt_channel()

    # Step 2: load the checkpoint into the training network.
    logger.warning("Load checkpoint.")
    restored_params, remove_redundancy = load_checkpoint()
    param_not_load, ckpt_not_load = load_param_into_net(network, restored_params, True, remove_redundancy)
    logger.warning(f"param_not_load: {param_not_load}")
    logger.warning(f"ckpt_not_load: {ckpt_not_load}")

    # Step 3: record where training resumes.
    resume_epoch = restored_params.get('epoch_num')
    resume_step = restored_params.get('step_num')
    # 'step_num' must expose asnumpy(); NOTE(review): a checkpoint without it fails here - confirm.
    model._initial_step = int(resume_step.asnumpy())
    logger.warning("Process training result error end.")
    return resume_epoch, resume_step
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _calc_cb_initial_step(org_epoch, org_step, *args, **kwargs):
    """
    Calculate the initial step for callbacks and position the dataset on it.

    ``args`` / ``kwargs`` are the arguments of the wrapped training call. The
    training dataset is expected at positional index 1, ``dataset_sink_mode``
    at index 3 and ``sink_size`` at index 4; each one may instead be passed by
    keyword.

    Args:
        org_epoch (int): Epoch index to resume from.
        org_step (int): Step index inside ``org_epoch`` to resume from.
        *args: Positional arguments of the wrapped training call.
        **kwargs: Keyword arguments of the wrapped training call.

    Returns:
        int, the step from which callbacks should continue counting.

    Raises:
        KeyError: If the training dataset is given neither positionally nor
            through the ``train_dataset`` keyword.
    """
    # The dataset may be passed by keyword; fall back to kwargs like the two
    # options below instead of raising IndexError on args[1].
    # NOTE(review): keyword name 'train_dataset' is assumed from the local name - confirm.
    train_dataset = args[1] if len(args) > 1 else kwargs['train_dataset']
    # NOTE(review): the True fallback should match the wrapped call's own default - confirm.
    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)

    cb_initial_step = 0
    if dataset_sink_mode:
        # In sink mode the dataset is positioned by epoch.
        train_dataset.set_init_step(org_epoch)
        dataset_size = train_dataset.get_dataset_size()
        # sink_size == -1 means the dataset size is used as the per-epoch step count.
        steps_per_epoch = dataset_size if sink_size == -1 else sink_size
        cb_initial_step = org_epoch * steps_per_epoch + org_step
    else:
        # Without sinking the dataset is positioned by step inside the epoch.
        train_dataset.set_init_step(org_step)
        cb_initial_step = org_step
    if hasattr(train_dataset, '_dataset_helper'):
        # A dataset helper is already attached, so reset it to the computed step too.
        dataset_helper = train_dataset._dataset_helper
        _reset_training_dataset(cb_initial_step, dataset_helper.iter.dataset.get_dataset_size())
    return cb_initial_step
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _update_ckpt_callback_info(resume_train_step, **kwargs):
    """
    Refresh the internal step counters of the checkpoint callback.

    Looks up ``kwargs['callbacks']``, which may be a single ``ModelCheckpoint``
    or a list of callbacks. For a list, the last ``ModelCheckpoint`` in it is
    used. Nothing happens when no such callback is present.

    Args:
        resume_train_step: Value stored in the callback's ``_append_step_num``.
        **kwargs: Keyword arguments of the training call; only ``callbacks`` is read.
    """
    callbacks = kwargs.get('callbacks')
    if not callbacks:
        return

    checkpoint_cb = callbacks if isinstance(callbacks, ModelCheckpoint) else None
    if isinstance(callbacks, list):
        # Keep scanning to the end so that the last ModelCheckpoint wins.
        for candidate in callbacks:
            if isinstance(candidate, ModelCheckpoint):
                checkpoint_cb = candidate

    if checkpoint_cb is None:
        return
    checkpoint_cb._last_triggered_step = 0
    checkpoint_cb._append_step_num = resume_train_step
|
|
235
|
+
|
|
236
|
+
|
|
168
237
|
def _handle_tft(func):
|
|
169
238
|
"""
|
|
170
239
|
Decorator function, which starts uce handle process when an exception occurs during training.
|
|
@@ -180,42 +249,34 @@ def _handle_tft(func):
|
|
|
180
249
|
if isinstance(item, TrainFaultTolerance):
|
|
181
250
|
obj = item
|
|
182
251
|
if obj:
|
|
183
|
-
tft = obj.tft
|
|
184
252
|
tft_env = os.getenv("MS_ENABLE_TFT", "")
|
|
185
|
-
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env
|
|
253
|
+
uce_env = "UCE:1" in tft_env or "ARF:1" in tft_env or "HCCE:1" in tft_env
|
|
254
|
+
tre_env = "TRE:1" in tft_env
|
|
186
255
|
while True:
|
|
187
256
|
try:
|
|
188
257
|
return func(self, *args, **kwargs)
|
|
189
258
|
except RuntimeError as e:
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
259
|
+
if tre_env and 'TREError' in str(e):
|
|
260
|
+
_, resume_step = _handle_training_result_error(self, obj)
|
|
261
|
+
repair_step = int(resume_step.asnumpy())
|
|
262
|
+
_update_ckpt_callback_info(repair_step, **kwargs)
|
|
263
|
+
logger.warning(f'Resume training after TREError from step {repair_step}.')
|
|
264
|
+
else:
|
|
265
|
+
_handle_exception_info(obj, uce_env, obj.tft, e)
|
|
266
|
+
ret = obj.tft.tft_wait_next_action()
|
|
267
|
+
if ret == obj.tft.Action.EXIT.value:
|
|
268
|
+
raise e
|
|
269
|
+
repair_step = obj.tft.tft_get_repair_step()
|
|
270
|
+
logger.warning(
|
|
271
|
+
"uce wrapper caught repair finish REPAIR STEP: {} batch_num:{}".format(repair_step,
|
|
272
|
+
self.batch_num))
|
|
198
273
|
initial_epoch = int(repair_step / self.batch_num)
|
|
199
274
|
initial_step = repair_step % self.batch_num
|
|
200
275
|
kwargs["initial_epoch"] = initial_epoch
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
cb_initial_step = 0
|
|
207
|
-
if dataset_sink_mode:
|
|
208
|
-
train_dataset.set_init_step(initial_epoch)
|
|
209
|
-
dataset_size = train_dataset.get_dataset_size()
|
|
210
|
-
if sink_size != -1:
|
|
211
|
-
cb_initial_step = initial_epoch * sink_size + initial_step
|
|
212
|
-
else:
|
|
213
|
-
cb_initial_step = initial_epoch * dataset_size + initial_step
|
|
214
|
-
else:
|
|
215
|
-
train_dataset.set_init_step(initial_step)
|
|
216
|
-
cb_initial_step = initial_step
|
|
217
|
-
|
|
218
|
-
kwargs["initial_step"] = cb_initial_step
|
|
276
|
+
cb_initial_step = _calc_cb_initial_step(initial_epoch, initial_step, *args, **kwargs)
|
|
277
|
+
if not self.enable_tre:
|
|
278
|
+
kwargs["initial_step"] = cb_initial_step
|
|
279
|
+
self._initial_step = 0
|
|
219
280
|
# reset all accu grads to zero
|
|
220
281
|
obj._reset_acc_grads()
|
|
221
282
|
logger.warning(
|
|
@@ -223,8 +284,9 @@ def _handle_tft(func):
|
|
|
223
284
|
cb_initial_step))
|
|
224
285
|
continue
|
|
225
286
|
except BaseException as e:
|
|
226
|
-
|
|
227
|
-
|
|
287
|
+
if obj.tft:
|
|
288
|
+
logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
|
|
289
|
+
obj.tft.tft_report_error(obj.tft.ReportState.RS_UNKNOWN.value)
|
|
228
290
|
raise e
|
|
229
291
|
else:
|
|
230
292
|
return func(self, *args, **kwargs)
|
|
@@ -241,9 +303,6 @@ def _check_tft():
|
|
|
241
303
|
ascend_target = MSContext.get_instance().get_ascend_soc_version()
|
|
242
304
|
if ascend_target == 'ascend910':
|
|
243
305
|
raise ValueError("TFT is not supported when using ascend910")
|
|
244
|
-
ms_mode = context.get_context("mode")
|
|
245
|
-
if ms_mode != mindspore.GRAPH_MODE:
|
|
246
|
-
raise ValueError("TFT is only supported in GRAPH_MODE")
|
|
247
306
|
jit_level = context.get_context("jit_level")
|
|
248
307
|
if jit_level == "O2" and ("UCE:1" in tft_env or "ARF:1" in tft_env):
|
|
249
308
|
raise ValueError("TFT is not supported when using jit_level == O2")
|
|
@@ -384,6 +443,11 @@ def _set_with_processed_inputs(network, inputs):
|
|
|
384
443
|
"Reset inputs from a process inputs, should be a list/tuple or a dict, but got %s!" % str(inputs))
|
|
385
444
|
|
|
386
445
|
|
|
446
|
+
def _check_tft_reset_dataset():
    """Report whether the ``MS_ENABLE_TFT`` environment variable enables a dataset-reset feature.

    Reads ``MS_ENABLE_TFT`` (an unset variable is treated as an empty string) and checks it
    for any of the substrings ``"TRE:1"``, ``"UCE:1"``, ``"HCCE:1"`` or ``"ARF:1"``.

    Returns:
        bool, ``True`` if at least one of those substrings occurs in the variable's value,
        ``False`` otherwise.
    """
    env_tft = os.getenv("MS_ENABLE_TFT", "")
    # Plain substring match, so the switches may appear anywhere in the value.
    reset_switches = ("TRE:1", "UCE:1", "HCCE:1", "ARF:1")
    # A generator lets any() short-circuit without building a temporary list.
    return any(switch in env_tft for switch in reset_switches)
|
|
449
|
+
|
|
450
|
+
|
|
387
451
|
class Model:
|
|
388
452
|
"""
|
|
389
453
|
High-Level API for training or inference.
|
|
@@ -501,6 +565,10 @@ class Model:
|
|
|
501
565
|
self._lite_infer = True # if backend lite infer fails, set False
|
|
502
566
|
self._mindspore_lite_model_group_id = id(self) & 0xFFFF
|
|
503
567
|
self.batch_num = -1
|
|
568
|
+
self.enable_tre = "TRE:1" in os.getenv("MS_ENABLE_TFT", "")
|
|
569
|
+
self.enable_hcce = "HCCE:1" in os.getenv("MS_ENABLE_TFT", "")
|
|
570
|
+
self._initial_step = None
|
|
571
|
+
self._need_reset_data = _check_tft_reset_dataset()
|
|
504
572
|
_clear_auto_parallel_context(self._network)
|
|
505
573
|
|
|
506
574
|
def _check_for_graph_cell(self, kwargs):
|
|
@@ -700,7 +768,7 @@ class Model:
|
|
|
700
768
|
logger.info("Begin to connect network with dataset.")
|
|
701
769
|
network = connect_network_with_dataset(network, dataset_helper)
|
|
702
770
|
|
|
703
|
-
if _get_recovery_context("enable_recovery") and is_train:
|
|
771
|
+
if (_get_recovery_context("enable_recovery") or self._need_reset_data) and is_train:
|
|
704
772
|
_set_training_dataset(dataset_helper)
|
|
705
773
|
|
|
706
774
|
network.set_train(is_train)
|
|
@@ -744,7 +812,7 @@ class Model:
|
|
|
744
812
|
"""
|
|
745
813
|
if os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") != "1":
|
|
746
814
|
return
|
|
747
|
-
if
|
|
815
|
+
if context.get_context("device_target") == "Ascend":
|
|
748
816
|
cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
|
|
749
817
|
cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
|
|
750
818
|
logger.info(f"need_ckpt:{cb_params.need_ckpt},"
|
|
@@ -812,8 +880,8 @@ class Model:
|
|
|
812
880
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
813
881
|
epoch (int): Total number of iterations on the data. Default: 1.
|
|
814
882
|
"""
|
|
815
|
-
if context.get_context("
|
|
816
|
-
raise RuntimeError('Pre-init process only supports
|
|
883
|
+
if context.get_context("device_target") != "Ascend":
|
|
884
|
+
raise RuntimeError('Pre-init process only supports Ascend target currently.')
|
|
817
885
|
|
|
818
886
|
if not train_dataset and not valid_dataset:
|
|
819
887
|
raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
|
|
@@ -957,6 +1025,7 @@ class Model:
|
|
|
957
1025
|
cb_params.latest_ckpt_file = None
|
|
958
1026
|
cb_params.loss_scale_mananger = self._loss_scale_manager
|
|
959
1027
|
cb_params.is_arf = _get_recovery_context("is_arf")
|
|
1028
|
+
cb_params.initial_step = self._initial_step
|
|
960
1029
|
|
|
961
1030
|
# build callback list
|
|
962
1031
|
with _CallbackManager(callbacks) as list_callback:
|
|
@@ -995,7 +1064,7 @@ class Model:
|
|
|
995
1064
|
initial_epoch (int): Epoch at which to start train, it used for resuming a previous training run.
|
|
996
1065
|
Default: 0.
|
|
997
1066
|
"""
|
|
998
|
-
is_graph =
|
|
1067
|
+
is_graph = context.get_context("mode") == context.GRAPH_MODE
|
|
999
1068
|
dataset_size = train_dataset.get_dataset_size()
|
|
1000
1069
|
if dataset_size % sink_size != 0:
|
|
1001
1070
|
logger.info("In dataset_sink mode (dataset_size % sink_size) should equal to 0, "
|
|
@@ -1064,6 +1133,7 @@ class Model:
|
|
|
1064
1133
|
if cb_params.is_arf:
|
|
1065
1134
|
cb_params.is_arf = False
|
|
1066
1135
|
_set_recovery_context(is_arf=False)
|
|
1136
|
+
_clean_rootinfo()
|
|
1067
1137
|
|
|
1068
1138
|
# Embedding cache server only run one step.
|
|
1069
1139
|
if is_embedding_cache_server:
|
|
@@ -1142,8 +1212,6 @@ class Model:
|
|
|
1142
1212
|
if not enable_recovery:
|
|
1143
1213
|
self.enable_recovery = False
|
|
1144
1214
|
else:
|
|
1145
|
-
if context.get_context("mode") != context.GRAPH_MODE:
|
|
1146
|
-
raise RuntimeError("Recovery for training only support graph mode currently.")
|
|
1147
1215
|
self.enable_recovery = enable_recovery and _is_role_worker()
|
|
1148
1216
|
|
|
1149
1217
|
def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
|
|
@@ -1278,6 +1346,7 @@ class Model:
|
|
|
1278
1346
|
if cb_params.is_arf:
|
|
1279
1347
|
cb_params.is_arf = False
|
|
1280
1348
|
_set_recovery_context(is_arf=False)
|
|
1349
|
+
_clean_rootinfo()
|
|
1281
1350
|
# Embedding cache server only run one step.
|
|
1282
1351
|
if is_embedding_cache_server:
|
|
1283
1352
|
break
|
|
@@ -2120,9 +2189,6 @@ class Model:
|
|
|
2120
2189
|
dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
|
|
2121
2190
|
sink_size (int): Control the amount of data in each sink.
|
|
2122
2191
|
"""
|
|
2123
|
-
if context.get_context("mode") != context.GRAPH_MODE:
|
|
2124
|
-
raise RuntimeError("Pre-compile process that generate parameter layout for the train network "
|
|
2125
|
-
"only supports GRAPH MODE and Ascend target currently.")
|
|
2126
2192
|
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
|
2127
2193
|
raise RuntimeError("'infer_train_layout' only supports 'semi_auto_parallel' and 'auto_parallel' "
|
|
2128
2194
|
"mode, but got {}.".format(_get_parallel_mode()))
|
|
@@ -2241,6 +2307,7 @@ class Model:
|
|
|
2241
2307
|
|
|
2242
2308
|
Examples:
|
|
2243
2309
|
>>> import numpy as np
|
|
2310
|
+
>>> import mindspore as ms
|
|
2244
2311
|
>>> import mindspore.nn as nn
|
|
2245
2312
|
>>> from mindspore import Tensor
|
|
2246
2313
|
>>> from mindspore.train import Model
|
|
@@ -2250,28 +2317,28 @@ class Model:
|
|
|
2250
2317
|
>>> from mindspore.parallel.auto_parallel import AutoParallel
|
|
2251
2318
|
>>>
|
|
2252
2319
|
>>> class Net(nn.Cell):
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
|
|
2256
|
-
|
|
2257
|
-
|
|
2258
|
-
|
|
2259
|
-
|
|
2260
|
-
|
|
2261
|
-
|
|
2262
|
-
|
|
2263
|
-
|
|
2264
|
-
|
|
2265
|
-
|
|
2266
|
-
|
|
2267
|
-
|
|
2268
|
-
|
|
2269
|
-
|
|
2270
|
-
|
|
2271
|
-
|
|
2272
|
-
|
|
2273
|
-
|
|
2274
|
-
|
|
2320
|
+
... def __init__(self):
|
|
2321
|
+
... super(Net, self).__init__()
|
|
2322
|
+
... self.fc1 = nn.Dense(128, 768, activation='relu')
|
|
2323
|
+
... self.fc2 = nn.Dense(128, 768, activation='relu')
|
|
2324
|
+
... self.fc3 = nn.Dense(128, 768, activation='relu')
|
|
2325
|
+
... self.fc4 = nn.Dense(768, 768, activation='relu')
|
|
2326
|
+
... self.relu4 = nn.ReLU()
|
|
2327
|
+
... self.relu5 = nn.ReLU()
|
|
2328
|
+
... self.transpose = P.Transpose()
|
|
2329
|
+
... self.matmul1 = P.MatMul()
|
|
2330
|
+
... self.matmul2 = P.MatMul()
|
|
2331
|
+
...
|
|
2332
|
+
... def construct(self, x):
|
|
2333
|
+
... q = self.fc1(x)
|
|
2334
|
+
... k = self.fc2(x)
|
|
2335
|
+
... v = self.fc3(x)
|
|
2336
|
+
... k = self.transpose(k, (1, 0))
|
|
2337
|
+
... c = self.relu4(self.matmul1(q, k))
|
|
2338
|
+
... s = self.relu5(self.matmul2(c, v))
|
|
2339
|
+
... s = self.fc4(s)
|
|
2340
|
+
... return s
|
|
2341
|
+
...
|
|
2275
2342
|
>>> ms.set_context(mode=ms.GRAPH_MODE)
|
|
2276
2343
|
>>> init()
|
|
2277
2344
|
>>> inputs = Tensor(np.ones([32, 128]).astype(np.float32))
|
|
@@ -2281,9 +2348,6 @@ class Model:
|
|
|
2281
2348
|
>>> predict_map = model.infer_predict_layout(inputs)
|
|
2282
2349
|
"""
|
|
2283
2350
|
_init_auto_parallel_context(self._network)
|
|
2284
|
-
if context.get_context("mode") != context.GRAPH_MODE:
|
|
2285
|
-
raise RuntimeError("Pre-compile process that generate parameter layout for the predict network "
|
|
2286
|
-
"only supports GRAPH MODE and Ascend target currently.")
|
|
2287
2351
|
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
|
2288
2352
|
raise RuntimeError('Infer predict layout only supports semi auto parallel and auto parallel mode.')
|
|
2289
2353
|
_parallel_predict_check()
|