mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.0-cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Refer to the linked advisory for more details (the link was not preserved in this copy).
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +46 -13
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +209 -29
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +310 -55
- mindspore/communication/management.py +14 -14
- mindspore/context.py +123 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/dnnl.dll +0 -0
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +495 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +266 -21
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +28 -7
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +275 -93
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +113 -3
- mindspore/nn/layer/embedding.py +120 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +734 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
- mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +490 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +558 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +184 -8
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +6 -1
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +12 -146
- mindspore/ops/operations/comm_ops.py +42 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +265 -10
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +28 -8
- mindspore/parallel/_cell_wrapper.py +83 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +81 -11
- mindspore/parallel/_utils.py +13 -1
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +280 -412
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +36 -103
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +28 -2
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +85 -22
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +134 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/dataset_helper.py +7 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +134 -58
- mindspore/train/serialization.py +336 -112
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +258 -252
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/train/mind_ir_pb2.py
CHANGED
|
@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
|
|
|
20
20
|
syntax='proto2',
|
|
21
21
|
serialized_options=None,
|
|
22
22
|
create_key=_descriptor._internal_create_key,
|
|
23
|
-
serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\x88\t\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x64\x18\x04 \x01(\x01\x12\t\n\x01s\x18\x05 \x01(\x0c\x12\x1f\n\x01t\x18\x06 \x01(\x0b\x32\x14.mind_ir.TensorProto\x12\x1e\n\x01g\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12\x0e\n\x06\x66loats\x18\x08 \x03(\x02\x12\x0f\n\x07\x64oubles\x18\t \x03(\x01\x12\x0c\n\x04ints\x18\n \x03(\x03\x12\x0f\n\x07strings\x18\x0b \x03(\x0c\x12%\n\x07tensors\x18\x0c \x03(\x0b\x32\x14.mind_ir.TensorProto\x12#\n\x06graphs\x18\r \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x12\n\ndoc_string\x18\x0e \x01(\t\x12\x15\n\rref_attr_name\x18\x0f \x01(\t\x12\x33\n\x04type\x18\x10 \x01(\x0e\x32%.mind_ir.AttributeProto.AttributeType\x12\'\n\x06values\x18\x11 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x36\n\x08seq_info\x18\x12 \x01(\x0b\x32$.mind_ir.AttributeProto.SeqInfoProto\x12&\n\x07\x66unctor\x18\x13 \x01(\x0b\x32\x15.mind_ir.FunctorProto\x1aT\n\x0cSeqInfoProto\x12\x12\n\nis_dyn_len\x18\x01 \x01(\x08\x12\x30\n\x0ftuple_elem_item\x18\x02 
\x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xaf\x04\n\rAttributeType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\n\n\x06TENSOR\x10\x11\x12\t\n\x05GRAPH\x10\x12\x12\x0b\n\x07TENSORS\x10\x13\x12\t\n\x05TUPLE\x10\x14\x12\x08\n\x04LIST\x10\x15\x12\x08\n\x04\x44ICT\x10\x16\x12\n\n\x06UMONAD\x10\x17\x12\x0b\n\x07IOMONAD\x10\x18\x12\x08\n\x04NONE\x10\x19\x12\x14\n\x10PRIMITIVECLOSURE\x10\x1a\x12\x14\n\x10\x46UNCGRAPHCLOSURE\x10\x1b\x12\x12\n\x0ePARTIALCLOSURE\x10\x1c\x12\x14\n\x10UNIONFUNCCLOSURE\x10\x1d\x12\x0e\n\nCSR_TENSOR\x10\x1e\x12\x0e\n\nCOO_TENSOR\x10\x1f\x12\x0e\n\nROW_TENSOR\x10 \x12\x0e\n\nCLASS_TYPE\x10!\x12\x0e\n\nNAME_SPACE\x10\"\x12\n\n\x06SYMBOL\x10#\x12\r\n\tTYPE_NULL\x10$\x12\x0e\n\nMAP_TENSOR\x10%\x12\x0b\n\x07\x46UNCTOR\x10&\x12\n\n\x06SCALAR\x10\'\"\
|
|
23
|
+
serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\x88\t\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x64\x18\x04 \x01(\x01\x12\t\n\x01s\x18\x05 \x01(\x0c\x12\x1f\n\x01t\x18\x06 \x01(\x0b\x32\x14.mind_ir.TensorProto\x12\x1e\n\x01g\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12\x0e\n\x06\x66loats\x18\x08 \x03(\x02\x12\x0f\n\x07\x64oubles\x18\t \x03(\x01\x12\x0c\n\x04ints\x18\n \x03(\x03\x12\x0f\n\x07strings\x18\x0b \x03(\x0c\x12%\n\x07tensors\x18\x0c \x03(\x0b\x32\x14.mind_ir.TensorProto\x12#\n\x06graphs\x18\r \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x12\n\ndoc_string\x18\x0e \x01(\t\x12\x15\n\rref_attr_name\x18\x0f \x01(\t\x12\x33\n\x04type\x18\x10 \x01(\x0e\x32%.mind_ir.AttributeProto.AttributeType\x12\'\n\x06values\x18\x11 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x36\n\x08seq_info\x18\x12 \x01(\x0b\x32$.mind_ir.AttributeProto.SeqInfoProto\x12&\n\x07\x66unctor\x18\x13 \x01(\x0b\x32\x15.mind_ir.FunctorProto\x1aT\n\x0cSeqInfoProto\x12\x12\n\nis_dyn_len\x18\x01 \x01(\x08\x12\x30\n\x0ftuple_elem_item\x18\x02 
\x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xaf\x04\n\rAttributeType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\n\n\x06TENSOR\x10\x11\x12\t\n\x05GRAPH\x10\x12\x12\x0b\n\x07TENSORS\x10\x13\x12\t\n\x05TUPLE\x10\x14\x12\x08\n\x04LIST\x10\x15\x12\x08\n\x04\x44ICT\x10\x16\x12\n\n\x06UMONAD\x10\x17\x12\x0b\n\x07IOMONAD\x10\x18\x12\x08\n\x04NONE\x10\x19\x12\x14\n\x10PRIMITIVECLOSURE\x10\x1a\x12\x14\n\x10\x46UNCGRAPHCLOSURE\x10\x1b\x12\x12\n\x0ePARTIALCLOSURE\x10\x1c\x12\x14\n\x10UNIONFUNCCLOSURE\x10\x1d\x12\x0e\n\nCSR_TENSOR\x10\x1e\x12\x0e\n\nCOO_TENSOR\x10\x1f\x12\x0e\n\nROW_TENSOR\x10 \x12\x0e\n\nCLASS_TYPE\x10!\x12\x0e\n\nNAME_SPACE\x10\"\x12\n\n\x06SYMBOL\x10#\x12\r\n\tTYPE_NULL\x10$\x12\x0e\n\nMAP_TENSOR\x10%\x12\x0b\n\x07\x46UNCTOR\x10&\x12\n\n\x06SCALAR\x10\'\"\xae\x01\n\x0c\x46unctorProto\x12/\n\x04type\x18\x01 \x01(\x0e\x32!.mind_ir.FunctorProto.FunctorType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\x06values\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"6\n\x0b\x46unctorType\x12\x16\n\x12SHAPE_CALC_FUNCTOR\x10\x01\x12\x0f\n\x0b\x41NY_FUNCTOR\x10\x02\"\x98\x01\n\x0eValueInfoProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06tensor\x18\x02 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x03 \x01(\t\x12\x12\n\ndenotation\x18\x04 \x01(\t\x12*\n\tattr_info\x18\x05 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xf3\x01\n\tNodeProto\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12*\n\tattribute\x18\x05 
\x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\x0e\n\x06\x64omain\x18\x07 \x01(\t\x12*\n\tnode_attr\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12,\n\x0bprimal_attr\x18\t \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf8\x03\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\t\x12\x15\n\rproducer_name\x18\x02 \x01(\t\x12\x18\n\x10producer_version\x18\x03 \x01(\t\x12\x0e\n\x06\x64omain\x18\x04 \x01(\t\x12\x15\n\rmodel_version\x18\x05 \x01(\t\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\"\n\x05graph\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12&\n\tfunctions\x18\x08 \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x30\n\x0cpreprocessor\x18\t \x01(\x0b\x32\x1a.mind_ir.PreprocessorProto\x12\x15\n\rlittle_endian\x18\n \x01(\x08\x12(\n\x08parallel\x18\x0b \x01(\x0b\x32\x16.mind_ir.ParallelProto\x12+\n\nprimitives\x18\x0c \x03(\x0b\x32\x17.mind_ir.PrimitiveProto\x12\x17\n\x0fmind_ir_version\x18\r \x01(\x03\x12\x34\n\tuser_info\x18\x0e \x03(\x0b\x32!.mind_ir.ModelProto.UserInfoEntry\x1a/\n\rUserInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x11PreprocessorProto\x12&\n\x02op\x18\x01 \x03(\x0b\x32\x1a.mind_ir.PreprocessOpProto\"\x91\x01\n\x11PreprocessOpProto\x12\x15\n\rinput_columns\x18\x01 \x01(\t\x12\x16\n\x0eoutput_columns\x18\x02 \x01(\t\x12\x17\n\x0fproject_columns\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12\x12\n\noperations\x18\x05 \x01(\t\x12\x0f\n\x07offload\x18\x06 \x01(\x08\"\xd2\x02\n\nGraphProto\x12 \n\x04node\x18\x01 \x03(\x0b\x32\x12.mind_ir.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\tparameter\x18\x03 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x04 \x01(\t\x12&\n\x05input\x18\x05 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\'\n\x06output\x18\x06 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\x12\n\nbprop_hash\x18\x07 \x01(\t\x12*\n\tattribute\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x16\n\x0e\x62prop_filepath\x18\t 
\x01(\t\x12.\n\rmap_parameter\x18\n \x03(\x0b\x32\x17.mind_ir.MapTensorProto\"\xda\x07\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12\x11\n\tdata_type\x18\x02 \x01(\x05\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x05\x12\x13\n\x0bstring_data\x18\x05 \x03(\x0c\x12\x12\n\nint64_data\x18\x06 \x03(\x03\x12\x0c\n\x04name\x18\x07 \x01(\t\x12\x12\n\ndoc_string\x18\x08 \x01(\t\x12\x10\n\x08raw_data\x18\t \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\n \x03(\x01\x12\x13\n\x0buint64_data\x18\x0b \x03(\x04\x12=\n\rexternal_data\x18\x0c \x01(\x0b\x32&.mind_ir.TensorProto.ExternalDataProto\x12\x0f\n\x07ref_key\x18\r \x01(\t\x12\x10\n\x08min_dims\x18\x0e \x03(\x03\x12\x10\n\x08max_dims\x18\x0f \x03(\x03\x12>\n\x10\x63ompression_type\x18\x10 \x01(\x0e\x32$.mind_ir.TensorProto.CompressionType\x12:\n\x0cquant_params\x18\x11 \x03(\x0b\x32$.mind_ir.TensorProto.QuantParamProto\x1a\x45\n\x11\x45xternalDataProto\x12\x10\n\x08location\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x03\x12\x0e\n\x06length\x18\x03 \x01(\x03\x1aV\n\x0fQuantParamProto\x12\x17\n\x0fquant_algo_name\x18\x01 \x02(\t\x12*\n\tattribute\x18\x02 
\x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf4\x01\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\x0b\n\x07\x46LOAT64\x10\x11\x12\x0b\n\x07QINT4X2\x10\x12\"u\n\x0f\x43ompressionType\x12\x12\n\x0eNO_COMPRESSION\x10\x00\x12\x0c\n\x08INDEXING\x10\x01\x12\n\n\x06SPARSE\x10\x02\x12\x07\n\x03\x46SE\x10\x03\x12\x0f\n\x0b\x42IT_PACKING\x10\x04\x12\x0b\n\x07\x46SE_INT\x10\x05\x12\r\n\tFSE_INFER\x10\x06\"\xd1\x01\n\x0eMapTensorProto\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\rdefault_value\x18\x02 \x02(\x0b\x32\x17.mind_ir.AttributeProto\x12(\n\nkey_tensor\x18\x03 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12*\n\x0cvalue_tensor\x18\x04 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12+\n\rstatus_tensor\x18\x05 \x02(\x0b\x32\x14.mind_ir.TensorProto\"5\n\rParallelProto\x12$\n\x06layout\x18\x01 \x03(\x0b\x32\x14.mind_ir.LayoutProto\"\xfd\x01\n\x0bLayoutProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x16\x64\x65vice_arrangement_int\x18\x02 \x03(\x03\x12\x16\n\x0etensor_map_int\x18\x03 \x03(\x03\x12\x17\n\x0fslice_shape_int\x18\x04 \x03(\x03\x12\x12\n\nfield_size\x18\x05 \x01(\x03\x12\x15\n\runiform_split\x18\x06 \x01(\x08\x12\x17\n\x0fopt_shard_group\x18\x07 \x01(\t\x12\x17\n\x0fpipeline_shared\x18\x08 \x01(\x08\x12\x0f\n\x07is_send\x18\t \x01(\x08\x12\x11\n\tpeer_rank\x18\n \x01(\x03\x12\x0e\n\x06sr_tag\x18\x0b \x01(\x03\"\xda\x01\n\x0ePrimitiveProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07op_type\x18\x02 \x01(\t\x12*\n\tattribute\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x15\n\rinstance_name\x18\x04 \x01(\t\x12\x33\n\tprim_type\x18\x05 
\x01(\x0e\x32 .mind_ir.PrimitiveProto.PrimType\"1\n\x08PrimType\x12\r\n\tPRIMITIVE\x10\x01\x12\x16\n\x12PRIMITIVE_FUNCTION\x10\x02*R\n\x07Version\x12\x14\n\x10IR_VERSION_START\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01\x12!\n\x1dIR_VERSION_WITH_PRIM_FUNCTION\x10\x02'
|
|
24
24
|
)
|
|
25
25
|
|
|
26
26
|
_VERSION = _descriptor.EnumDescriptor(
|
|
@@ -48,8 +48,8 @@ _VERSION = _descriptor.EnumDescriptor(
|
|
|
48
48
|
],
|
|
49
49
|
containing_type=None,
|
|
50
50
|
serialized_options=None,
|
|
51
|
-
serialized_start=
|
|
52
|
-
serialized_end=
|
|
51
|
+
serialized_start=4557,
|
|
52
|
+
serialized_end=4639,
|
|
53
53
|
)
|
|
54
54
|
_sym_db.RegisterEnumDescriptor(_VERSION)
|
|
55
55
|
|
|
@@ -286,11 +286,16 @@ _FUNCTORPROTO_FUNCTORTYPE = _descriptor.EnumDescriptor(
|
|
|
286
286
|
serialized_options=None,
|
|
287
287
|
type=None,
|
|
288
288
|
create_key=_descriptor._internal_create_key),
|
|
289
|
+
_descriptor.EnumValueDescriptor(
|
|
290
|
+
name='ANY_FUNCTOR', index=1, number=2,
|
|
291
|
+
serialized_options=None,
|
|
292
|
+
type=None,
|
|
293
|
+
create_key=_descriptor._internal_create_key),
|
|
289
294
|
],
|
|
290
295
|
containing_type=None,
|
|
291
296
|
serialized_options=None,
|
|
292
297
|
serialized_start=1310,
|
|
293
|
-
serialized_end=
|
|
298
|
+
serialized_end=1364,
|
|
294
299
|
)
|
|
295
300
|
_sym_db.RegisterEnumDescriptor(_FUNCTORPROTO_FUNCTORTYPE)
|
|
296
301
|
|
|
@@ -399,8 +404,8 @@ _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor(
|
|
|
399
404
|
],
|
|
400
405
|
containing_type=None,
|
|
401
406
|
serialized_options=None,
|
|
402
|
-
serialized_start=
|
|
403
|
-
serialized_end=
|
|
407
|
+
serialized_start=3448,
|
|
408
|
+
serialized_end=3692,
|
|
404
409
|
)
|
|
405
410
|
_sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE)
|
|
406
411
|
|
|
@@ -449,8 +454,8 @@ _TENSORPROTO_COMPRESSIONTYPE = _descriptor.EnumDescriptor(
|
|
|
449
454
|
],
|
|
450
455
|
containing_type=None,
|
|
451
456
|
serialized_options=None,
|
|
452
|
-
serialized_start=
|
|
453
|
-
serialized_end=
|
|
457
|
+
serialized_start=3694,
|
|
458
|
+
serialized_end=3811,
|
|
454
459
|
)
|
|
455
460
|
_sym_db.RegisterEnumDescriptor(_TENSORPROTO_COMPRESSIONTYPE)
|
|
456
461
|
|
|
@@ -474,8 +479,8 @@ _PRIMITIVEPROTO_PRIMTYPE = _descriptor.EnumDescriptor(
|
|
|
474
479
|
],
|
|
475
480
|
containing_type=None,
|
|
476
481
|
serialized_options=None,
|
|
477
|
-
serialized_start=
|
|
478
|
-
serialized_end=
|
|
482
|
+
serialized_start=4506,
|
|
483
|
+
serialized_end=4555,
|
|
479
484
|
)
|
|
480
485
|
_sym_db.RegisterEnumDescriptor(_PRIMITIVEPROTO_PRIMTYPE)
|
|
481
486
|
|
|
@@ -720,7 +725,7 @@ _FUNCTORPROTO = _descriptor.Descriptor(
|
|
|
720
725
|
oneofs=[
|
|
721
726
|
],
|
|
722
727
|
serialized_start=1190,
|
|
723
|
-
serialized_end=
|
|
728
|
+
serialized_end=1364,
|
|
724
729
|
)
|
|
725
730
|
|
|
726
731
|
|
|
@@ -779,8 +784,8 @@ _VALUEINFOPROTO = _descriptor.Descriptor(
|
|
|
779
784
|
extension_ranges=[],
|
|
780
785
|
oneofs=[
|
|
781
786
|
],
|
|
782
|
-
serialized_start=
|
|
783
|
-
serialized_end=
|
|
787
|
+
serialized_start=1367,
|
|
788
|
+
serialized_end=1519,
|
|
784
789
|
)
|
|
785
790
|
|
|
786
791
|
|
|
@@ -867,8 +872,8 @@ _NODEPROTO = _descriptor.Descriptor(
|
|
|
867
872
|
extension_ranges=[],
|
|
868
873
|
oneofs=[
|
|
869
874
|
],
|
|
870
|
-
serialized_start=
|
|
871
|
-
serialized_end=
|
|
875
|
+
serialized_start=1522,
|
|
876
|
+
serialized_end=1765,
|
|
872
877
|
)
|
|
873
878
|
|
|
874
879
|
|
|
@@ -906,8 +911,8 @@ _MODELPROTO_USERINFOENTRY = _descriptor.Descriptor(
|
|
|
906
911
|
extension_ranges=[],
|
|
907
912
|
oneofs=[
|
|
908
913
|
],
|
|
909
|
-
serialized_start=
|
|
910
|
-
serialized_end=
|
|
914
|
+
serialized_start=2225,
|
|
915
|
+
serialized_end=2272,
|
|
911
916
|
)
|
|
912
917
|
|
|
913
918
|
_MODELPROTO = _descriptor.Descriptor(
|
|
@@ -1028,8 +1033,8 @@ _MODELPROTO = _descriptor.Descriptor(
|
|
|
1028
1033
|
extension_ranges=[],
|
|
1029
1034
|
oneofs=[
|
|
1030
1035
|
],
|
|
1031
|
-
serialized_start=
|
|
1032
|
-
serialized_end=
|
|
1036
|
+
serialized_start=1768,
|
|
1037
|
+
serialized_end=2272,
|
|
1033
1038
|
)
|
|
1034
1039
|
|
|
1035
1040
|
|
|
@@ -1060,8 +1065,8 @@ _PREPROCESSORPROTO = _descriptor.Descriptor(
|
|
|
1060
1065
|
extension_ranges=[],
|
|
1061
1066
|
oneofs=[
|
|
1062
1067
|
],
|
|
1063
|
-
serialized_start=
|
|
1064
|
-
serialized_end=
|
|
1068
|
+
serialized_start=2274,
|
|
1069
|
+
serialized_end=2333,
|
|
1065
1070
|
)
|
|
1066
1071
|
|
|
1067
1072
|
|
|
@@ -1127,8 +1132,8 @@ _PREPROCESSOPPROTO = _descriptor.Descriptor(
|
|
|
1127
1132
|
extension_ranges=[],
|
|
1128
1133
|
oneofs=[
|
|
1129
1134
|
],
|
|
1130
|
-
serialized_start=
|
|
1131
|
-
serialized_end=
|
|
1135
|
+
serialized_start=2336,
|
|
1136
|
+
serialized_end=2481,
|
|
1132
1137
|
)
|
|
1133
1138
|
|
|
1134
1139
|
|
|
@@ -1222,8 +1227,8 @@ _GRAPHPROTO = _descriptor.Descriptor(
|
|
|
1222
1227
|
extension_ranges=[],
|
|
1223
1228
|
oneofs=[
|
|
1224
1229
|
],
|
|
1225
|
-
serialized_start=
|
|
1226
|
-
serialized_end=
|
|
1230
|
+
serialized_start=2484,
|
|
1231
|
+
serialized_end=2822,
|
|
1227
1232
|
)
|
|
1228
1233
|
|
|
1229
1234
|
|
|
@@ -1268,8 +1273,8 @@ _TENSORPROTO_EXTERNALDATAPROTO = _descriptor.Descriptor(
|
|
|
1268
1273
|
extension_ranges=[],
|
|
1269
1274
|
oneofs=[
|
|
1270
1275
|
],
|
|
1271
|
-
serialized_start=
|
|
1272
|
-
serialized_end=
|
|
1276
|
+
serialized_start=3288,
|
|
1277
|
+
serialized_end=3357,
|
|
1273
1278
|
)
|
|
1274
1279
|
|
|
1275
1280
|
_TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
|
|
@@ -1306,8 +1311,8 @@ _TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
|
|
|
1306
1311
|
extension_ranges=[],
|
|
1307
1312
|
oneofs=[
|
|
1308
1313
|
],
|
|
1309
|
-
serialized_start=
|
|
1310
|
-
serialized_end=
|
|
1314
|
+
serialized_start=3359,
|
|
1315
|
+
serialized_end=3445,
|
|
1311
1316
|
)
|
|
1312
1317
|
|
|
1313
1318
|
_TENSORPROTO = _descriptor.Descriptor(
|
|
@@ -1451,8 +1456,8 @@ _TENSORPROTO = _descriptor.Descriptor(
|
|
|
1451
1456
|
extension_ranges=[],
|
|
1452
1457
|
oneofs=[
|
|
1453
1458
|
],
|
|
1454
|
-
serialized_start=
|
|
1455
|
-
serialized_end=
|
|
1459
|
+
serialized_start=2825,
|
|
1460
|
+
serialized_end=3811,
|
|
1456
1461
|
)
|
|
1457
1462
|
|
|
1458
1463
|
|
|
@@ -1511,8 +1516,8 @@ _MAPTENSORPROTO = _descriptor.Descriptor(
|
|
|
1511
1516
|
extension_ranges=[],
|
|
1512
1517
|
oneofs=[
|
|
1513
1518
|
],
|
|
1514
|
-
serialized_start=
|
|
1515
|
-
serialized_end=
|
|
1519
|
+
serialized_start=3814,
|
|
1520
|
+
serialized_end=4023,
|
|
1516
1521
|
)
|
|
1517
1522
|
|
|
1518
1523
|
|
|
@@ -1543,8 +1548,8 @@ _PARALLELPROTO = _descriptor.Descriptor(
|
|
|
1543
1548
|
extension_ranges=[],
|
|
1544
1549
|
oneofs=[
|
|
1545
1550
|
],
|
|
1546
|
-
serialized_start=
|
|
1547
|
-
serialized_end=
|
|
1551
|
+
serialized_start=4025,
|
|
1552
|
+
serialized_end=4078,
|
|
1548
1553
|
)
|
|
1549
1554
|
|
|
1550
1555
|
|
|
@@ -1645,8 +1650,8 @@ _LAYOUTPROTO = _descriptor.Descriptor(
|
|
|
1645
1650
|
extension_ranges=[],
|
|
1646
1651
|
oneofs=[
|
|
1647
1652
|
],
|
|
1648
|
-
serialized_start=
|
|
1649
|
-
serialized_end=
|
|
1653
|
+
serialized_start=4081,
|
|
1654
|
+
serialized_end=4334,
|
|
1650
1655
|
)
|
|
1651
1656
|
|
|
1652
1657
|
|
|
@@ -1706,8 +1711,8 @@ _PRIMITIVEPROTO = _descriptor.Descriptor(
|
|
|
1706
1711
|
extension_ranges=[],
|
|
1707
1712
|
oneofs=[
|
|
1708
1713
|
],
|
|
1709
|
-
serialized_start=
|
|
1710
|
-
serialized_end=
|
|
1714
|
+
serialized_start=4337,
|
|
1715
|
+
serialized_end=4555,
|
|
1711
1716
|
)
|
|
1712
1717
|
|
|
1713
1718
|
_ATTRIBUTEPROTO_SEQINFOPROTO.fields_by_name['tuple_elem_item'].message_type = _ATTRIBUTEPROTO
|
mindspore/train/model.py
CHANGED
|
@@ -36,7 +36,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
|
|
|
36
36
|
from mindspore._checkparam import check_input_data, check_output_data
|
|
37
37
|
from mindspore import _checkparam as Validator
|
|
38
38
|
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
|
|
39
|
-
FlopsUtilizationCollector,
|
|
39
|
+
FlopsUtilizationCollector, TFTRegister
|
|
40
40
|
from mindspore.train.callback import __all__ as internal_cb_names
|
|
41
41
|
from mindspore.train.callback._cluster_monitor import ClusterMonitor
|
|
42
42
|
from mindspore import context
|
|
@@ -119,6 +119,101 @@ def _save_final_ckpt(func):
|
|
|
119
119
|
func(self, *args, **kwargs)
|
|
120
120
|
return wrapper
|
|
121
121
|
|
|
122
|
+
def _handle_tft(func):
|
|
123
|
+
"""
|
|
124
|
+
Decorator function, which starts uce handle process when an exception occurs during training.
|
|
125
|
+
"""
|
|
126
|
+
@wraps(func)
|
|
127
|
+
def wrapper(self, *args, **kwargs):
|
|
128
|
+
obj = None
|
|
129
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
|
|
130
|
+
obj = kwargs.get('callbacks')
|
|
131
|
+
if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
|
|
132
|
+
for item in kwargs.get('callbacks'):
|
|
133
|
+
if isinstance(item, TFTRegister):
|
|
134
|
+
obj = item
|
|
135
|
+
if obj:
|
|
136
|
+
tft = obj.tft
|
|
137
|
+
tft_env = os.getenv("MS_ENABLE_TFT", "")
|
|
138
|
+
uce_env = "UCE:1" in tft_env
|
|
139
|
+
while True:
|
|
140
|
+
try:
|
|
141
|
+
return func(self, *args, **kwargs)
|
|
142
|
+
except RuntimeError as e:
|
|
143
|
+
logger.info("uce wrapper caught RuntimeError")
|
|
144
|
+
if not uce_env:
|
|
145
|
+
logger.info("uce wrapper caught RuntimeError uce not enable")
|
|
146
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
147
|
+
raise e
|
|
148
|
+
e_str = str(e)
|
|
149
|
+
logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
150
|
+
if "UCEError" in e_str:
|
|
151
|
+
logger.info("uce wrapper report UCEError")
|
|
152
|
+
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
153
|
+
elif "ForceStopError" in e_str:
|
|
154
|
+
logger.info("uce wrapper caught RuntimeError ForceStopError")
|
|
155
|
+
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
156
|
+
tft.tft_report_error(force_stop_err)
|
|
157
|
+
else:
|
|
158
|
+
logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
|
|
159
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
160
|
+
raise e
|
|
161
|
+
ret = tft.tft_wait_next_action()
|
|
162
|
+
if ret == tft.Action.EXIT.value:
|
|
163
|
+
raise e
|
|
164
|
+
repair_step = tft.tft_get_repair_step()
|
|
165
|
+
logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
|
|
166
|
+
{}".format(repair_step, self.batch_num))
|
|
167
|
+
initial_epoch = int(repair_step/self.batch_num)
|
|
168
|
+
initial_step = repair_step % self.batch_num
|
|
169
|
+
kwargs["initial_epoch"] = initial_epoch
|
|
170
|
+
|
|
171
|
+
train_dataset = args[1]
|
|
172
|
+
dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
|
|
173
|
+
sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
|
|
174
|
+
|
|
175
|
+
cb_initial_step = 0
|
|
176
|
+
if dataset_sink_mode:
|
|
177
|
+
train_dataset.set_init_step(initial_epoch)
|
|
178
|
+
dataset_size = train_dataset.get_dataset_size()
|
|
179
|
+
if sink_size != -1:
|
|
180
|
+
cb_initial_step = initial_epoch * sink_size + initial_step
|
|
181
|
+
else:
|
|
182
|
+
cb_initial_step = initial_epoch * dataset_size + initial_step
|
|
183
|
+
else:
|
|
184
|
+
train_dataset.set_init_step(initial_step)
|
|
185
|
+
cb_initial_step = initial_step
|
|
186
|
+
|
|
187
|
+
kwargs["initial_step"] = cb_initial_step
|
|
188
|
+
|
|
189
|
+
logger.info("uce wrapper repair complete \
|
|
190
|
+
initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
|
|
191
|
+
continue
|
|
192
|
+
except BaseException as e:
|
|
193
|
+
logger.info("uce wrapper caught BaseException error")
|
|
194
|
+
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
195
|
+
raise e
|
|
196
|
+
else:
|
|
197
|
+
return func(self, *args, **kwargs)
|
|
198
|
+
return wrapper
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _check_tft():
|
|
202
|
+
"""Check if TFT is supported"""
|
|
203
|
+
tft_env = os.getenv("MS_ENABLE_TFT")
|
|
204
|
+
device_target = context.get_context("device_target")
|
|
205
|
+
if tft_env and device_target == "Ascend":
|
|
206
|
+
from mindspore._c_expression import MSContext
|
|
207
|
+
ascend_target = MSContext.get_instance().get_ascend_soc_version()
|
|
208
|
+
if ascend_target == 'ascend910':
|
|
209
|
+
raise ValueError("TFT is not supported when using ascend910")
|
|
210
|
+
ms_mode = context.get_context("mode")
|
|
211
|
+
if ms_mode != mindspore.GRAPH_MODE:
|
|
212
|
+
raise ValueError("TFT is only supported in GRAPH_MODE")
|
|
213
|
+
jit_level = context.get_context("jit_level")
|
|
214
|
+
if jit_level == "O2" and "UCE:1" in tft_env:
|
|
215
|
+
raise ValueError("TFT is not supported when using jit_level == O2")
|
|
216
|
+
|
|
122
217
|
|
|
123
218
|
def _append_ccae(callbacks):
|
|
124
219
|
"""Add cluster monitoring when CCAE is enabled."""
|
|
@@ -290,21 +385,11 @@ class Model:
|
|
|
290
385
|
amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
|
|
291
386
|
precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
|
|
292
387
|
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
The operators in the whitelist: [Conv1d, Conv2d, Conv3d, Conv1dTranspose, Conv2dTranspose,
|
|
296
|
-
Conv3dTranspose, Dense, LSTMCell, RNNCell, GRUCell, MatMul, BatchMatMul, PReLU, ReLU, Ger].
|
|
297
|
-
- "O2": Cast network to float16, keep BatchNorm run in float32, using dynamic loss scale.
|
|
298
|
-
- "O3": Cast network to float16, the BatchNorm is also cast to float16, loss scale will not be used.
|
|
299
|
-
- "auto": Set level to recommended level in different devices. Set level to "O2" on GPU, set
|
|
300
|
-
level to "O3" on Ascend. The recommended level is chosen by the expert experience, not applicable to all
|
|
301
|
-
scenarios. User should specify the level for special network.
|
|
302
|
-
|
|
303
|
-
"O2" is recommended on GPU, "O3" is recommended on Ascend.
|
|
388
|
+
For details on `amp_level` , refer to :func:`mindspore.amp.auto_mixed_precision`.
|
|
389
|
+
|
|
304
390
|
The BatchNorm strategy can be changed by `keep_batchnorm_fp32` settings in `kwargs`. `keep_batchnorm_fp32`
|
|
305
391
|
must be a bool. The loss scale strategy can be changed by `loss_scale_manager` setting in `kwargs`.
|
|
306
392
|
`loss_scale_manager` should be a subclass of :class:`mindspore.amp.LossScaleManager`.
|
|
307
|
-
The more detailed explanation of `amp_level` setting can be found at `mindspore.amp.build_train_network`.
|
|
308
393
|
|
|
309
394
|
boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
|
|
310
395
|
training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
|
|
@@ -379,6 +464,7 @@ class Model:
|
|
|
379
464
|
self._mindspore_lite = None
|
|
380
465
|
self._lite_infer = True # if backend lite infer fails, set False
|
|
381
466
|
self._mindspore_lite_model_group_id = id(self) & 0xFFFF
|
|
467
|
+
self.batch_num = -1
|
|
382
468
|
|
|
383
469
|
def _check_for_graph_cell(self, kwargs):
|
|
384
470
|
"""Check for graph cell"""
|
|
@@ -568,9 +654,11 @@ class Model:
|
|
|
568
654
|
dataset.__loop_size__ = 1
|
|
569
655
|
|
|
570
656
|
if dataset_helper is None:
|
|
657
|
+
logger.info("Begin to create DatasetHelper.")
|
|
571
658
|
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
|
572
659
|
|
|
573
660
|
if dataset_sink_mode:
|
|
661
|
+
logger.info("Begin to connect network with dataset.")
|
|
574
662
|
network = connect_network_with_dataset(network, dataset_helper)
|
|
575
663
|
|
|
576
664
|
if _get_recovery_context("enable_recovery") and is_train:
|
|
@@ -589,6 +677,10 @@ class Model:
|
|
|
589
677
|
if self._backbone_is_train != is_train:
|
|
590
678
|
network.set_train(is_train)
|
|
591
679
|
self._backbone_is_train = is_train
|
|
680
|
+
# Mode train and eval are the same net, network will be set_grad in _build_train_network.
|
|
681
|
+
# But if mode just want to do predict or eval, must set network set_grad False
|
|
682
|
+
if not is_train:
|
|
683
|
+
network.set_grad(False)
|
|
592
684
|
return network
|
|
593
685
|
|
|
594
686
|
def _check_need_ckpt(self, callbacks):
|
|
@@ -687,6 +779,7 @@ class Model:
|
|
|
687
779
|
if not train_dataset and not valid_dataset:
|
|
688
780
|
raise ValueError("The argument 'train_dataset' and 'valid_dataset' can not both be None or empty.")
|
|
689
781
|
|
|
782
|
+
logger.info("Begin to check device number in model.build() procedure.")
|
|
690
783
|
_device_number_check(self._parallel_mode, self._device_number)
|
|
691
784
|
|
|
692
785
|
if train_dataset:
|
|
@@ -694,27 +787,34 @@ class Model:
|
|
|
694
787
|
raise TypeError("The type of 'train_dataset' must be `Dataset`, "
|
|
695
788
|
"but got {}.".format(type(train_dataset)))
|
|
696
789
|
|
|
790
|
+
logger.info("Begin to check parameter broadcast in model.build() procedure.")
|
|
697
791
|
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
|
|
698
792
|
if self._parameter_broadcast:
|
|
699
793
|
self._train_network.set_broadcast_flag()
|
|
700
794
|
|
|
795
|
+
logger.info("Begin to exec preprocess in model.build() procedure.")
|
|
701
796
|
train_dataset.__no_send__ = True
|
|
702
797
|
train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
|
|
703
798
|
dataset=train_dataset,
|
|
704
799
|
dataset_sink_mode=True,
|
|
705
800
|
sink_size=sink_size)
|
|
801
|
+
logger.info("Begin to warmup dataset in model.build() procedure.")
|
|
706
802
|
self._warmup_dataset(epoch, train_dataset, sink_size)
|
|
707
803
|
|
|
708
804
|
# Since dataset pipeline has been triggered, delete flag
|
|
709
805
|
delattr(train_dataset, "__no_send__")
|
|
710
806
|
|
|
711
807
|
# Waiting for the dataset warmup ready
|
|
808
|
+
logger.info("Begin waiting for dataset warmup in model.build() procedure.")
|
|
712
809
|
self._waiting_for_dataset_warmup_ready(train_dataset)
|
|
810
|
+
logger.info("The dataset warmup was successful in model.build() procedure.")
|
|
713
811
|
|
|
714
812
|
if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
|
|
715
813
|
train_network.add_flags_recursive(is_first_iteration=True)
|
|
716
814
|
for inputs in train_dataset_helper:
|
|
815
|
+
logger.info("Begin to compile train network in model.build() procedure.")
|
|
717
816
|
train_network.compile(*inputs)
|
|
817
|
+
self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
|
|
718
818
|
break
|
|
719
819
|
|
|
720
820
|
if valid_dataset:
|
|
@@ -732,6 +832,7 @@ class Model:
|
|
|
732
832
|
if context.get_auto_parallel_context("pipeline_stages") > 1:
|
|
733
833
|
eval_network.add_flags_recursive(is_first_iteration=False)
|
|
734
834
|
for inputs in valid_dataset_helper:
|
|
835
|
+
logger.info("Begin to compile eval network in model.build() procedure.")
|
|
735
836
|
eval_network.compile(*inputs)
|
|
736
837
|
break
|
|
737
838
|
|
|
@@ -746,9 +847,10 @@ class Model:
|
|
|
746
847
|
|
|
747
848
|
return [callbacks]
|
|
748
849
|
|
|
850
|
+
@_handle_tft
|
|
749
851
|
@_save_final_ckpt
|
|
750
852
|
def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
|
|
751
|
-
valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True):
|
|
853
|
+
valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True, initial_step=0):
|
|
752
854
|
"""
|
|
753
855
|
Training.
|
|
754
856
|
|
|
@@ -772,12 +874,14 @@ class Model:
|
|
|
772
874
|
self._train_network.set_broadcast_flag()
|
|
773
875
|
|
|
774
876
|
cb_params = _InternalCallbackParam()
|
|
877
|
+
cb_params.cur_step_num = initial_step
|
|
775
878
|
cb_params.train_network = self._train_network
|
|
776
879
|
cb_params.epoch_num = epoch - initial_epoch
|
|
777
880
|
if dataset_sink_mode and sink_size > 0:
|
|
778
881
|
cb_params.batch_num = sink_size
|
|
779
882
|
else:
|
|
780
883
|
cb_params.batch_num = train_dataset.get_dataset_size()
|
|
884
|
+
self.batch_num = cb_params.batch_num
|
|
781
885
|
cb_params.mode = "train"
|
|
782
886
|
cb_params.loss_fn = self._loss_fn
|
|
783
887
|
cb_params.optimizer = self._optimizer
|
|
@@ -806,11 +910,13 @@ class Model:
|
|
|
806
910
|
with _CallbackManager(callbacks) as list_callback:
|
|
807
911
|
self._check_reuse_dataset(train_dataset)
|
|
808
912
|
if not dataset_sink_mode:
|
|
809
|
-
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
|
|
913
|
+
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
|
|
914
|
+
valid_infos)
|
|
810
915
|
elif context.get_context("device_target") == "CPU":
|
|
811
916
|
logger.info("The CPU cannot support dataset sink mode currently."
|
|
812
917
|
"So the training process will be performed with dataset not sink.")
|
|
813
|
-
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
|
|
918
|
+
self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
|
|
919
|
+
valid_infos)
|
|
814
920
|
else:
|
|
815
921
|
self._train_dataset_sink_process(epoch, train_dataset, list_callback,
|
|
816
922
|
cb_params, sink_size, initial_epoch, valid_infos)
|
|
@@ -850,9 +956,7 @@ class Model:
|
|
|
850
956
|
train_dataset.__total_batch__ = epoch * sink_size
|
|
851
957
|
|
|
852
958
|
cb_params.sink_size = sink_size
|
|
853
|
-
cb_params.cur_step_num = 0
|
|
854
959
|
cb_params.dataset_sink_mode = True
|
|
855
|
-
|
|
856
960
|
run_context = RunContext(cb_params)
|
|
857
961
|
list_callback.on_train_begin(run_context)
|
|
858
962
|
# used to stop training for early stop, such as stopAtTIme or stopATStep
|
|
@@ -861,7 +965,6 @@ class Model:
|
|
|
861
965
|
dataset_helper = train_dataset._dataset_helper
|
|
862
966
|
|
|
863
967
|
self.epoch_iter = 0
|
|
864
|
-
|
|
865
968
|
self._check_enable_recovery()
|
|
866
969
|
# Used to check whether need perform recovery for process which is restarted.
|
|
867
970
|
self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
|
|
@@ -997,7 +1100,6 @@ class Model:
|
|
|
997
1100
|
dataset_size (int): The number of batches in a dataset.
|
|
998
1101
|
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
999
1102
|
"""
|
|
1000
|
-
|
|
1001
1103
|
if not self.enable_recovery:
|
|
1002
1104
|
self.need_load_ckpt = False
|
|
1003
1105
|
|
|
@@ -1084,7 +1186,6 @@ class Model:
|
|
|
1084
1186
|
dataset=train_dataset,
|
|
1085
1187
|
dataset_sink_mode=False,
|
|
1086
1188
|
epoch_num=epoch)
|
|
1087
|
-
cb_params.cur_step_num = 0
|
|
1088
1189
|
cb_params.dataset_sink_mode = False
|
|
1089
1190
|
run_context = RunContext(cb_params)
|
|
1090
1191
|
list_callback.on_train_begin(run_context)
|
|
@@ -1106,7 +1207,6 @@ class Model:
|
|
|
1106
1207
|
"returned by 'train_dataset'".format(len_element))
|
|
1107
1208
|
cb_params.cur_step_num += 1
|
|
1108
1209
|
self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
|
|
1109
|
-
|
|
1110
1210
|
cb_params.train_dataset_element = next_element
|
|
1111
1211
|
list_callback.on_train_step_begin(run_context)
|
|
1112
1212
|
self._check_network_mode(self._train_network, True)
|
|
@@ -1150,31 +1250,6 @@ class Model:
|
|
|
1150
1250
|
|
|
1151
1251
|
list_callback.on_train_end(run_context)
|
|
1152
1252
|
|
|
1153
|
-
def _wrapper_train(self, callbacks):
|
|
1154
|
-
"""
|
|
1155
|
-
This method used to wrap train function with ttp wrapper which will do event notify when
|
|
1156
|
-
exceptions throw.
|
|
1157
|
-
|
|
1158
|
-
Args:
|
|
1159
|
-
callbacks (function): Callbacks passed by train method.
|
|
1160
|
-
"""
|
|
1161
|
-
|
|
1162
|
-
if not callbacks:
|
|
1163
|
-
return self._train
|
|
1164
|
-
cbs = callbacks if isinstance(callbacks, list) else [callbacks]
|
|
1165
|
-
obj = None
|
|
1166
|
-
_train_wrapper = None
|
|
1167
|
-
for item in cbs:
|
|
1168
|
-
if isinstance(item, MindIOTTPAdapter):
|
|
1169
|
-
obj = item
|
|
1170
|
-
|
|
1171
|
-
if (obj is not None) and (obj.enable is True):
|
|
1172
|
-
logger.info("MindIO TTP is enable, so we wrapper ttp exception handdler for self train method.")
|
|
1173
|
-
_train_wrapper = obj.wrapper_ttp_persist(self._train)
|
|
1174
|
-
|
|
1175
|
-
return self._train if not _train_wrapper else _train_wrapper
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
1253
|
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
|
|
1179
1254
|
"""
|
|
1180
1255
|
Training API.
|
|
@@ -1240,9 +1315,10 @@ class Model:
|
|
|
1240
1315
|
... loss_scale_manager=loss_scale_manager)
|
|
1241
1316
|
>>> model.train(2, dataset)
|
|
1242
1317
|
"""
|
|
1318
|
+
_check_tft()
|
|
1319
|
+
device_target = context.get_context("device_target")
|
|
1243
1320
|
# prepare dataset for obfuscated model
|
|
1244
1321
|
train_dataset = self._prepare_obf_dataset(train_dataset)
|
|
1245
|
-
device_target = context.get_context("device_target")
|
|
1246
1322
|
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
|
1247
1323
|
logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
|
|
1248
1324
|
dataset_sink_mode = False
|
|
@@ -1283,16 +1359,14 @@ class Model:
|
|
|
1283
1359
|
_device_number_check(self._parallel_mode, self._device_number)
|
|
1284
1360
|
|
|
1285
1361
|
callbacks = _append_ccae(callbacks)
|
|
1286
|
-
_train_wrapper = None
|
|
1287
1362
|
if callbacks:
|
|
1288
1363
|
self._check_methods_for_custom_callbacks(callbacks, "train")
|
|
1289
|
-
|
|
1290
|
-
|
|
1291
|
-
|
|
1292
|
-
|
|
1293
|
-
|
|
1294
|
-
|
|
1295
|
-
initial_epoch=initial_epoch)
|
|
1364
|
+
self._train(epoch,
|
|
1365
|
+
train_dataset,
|
|
1366
|
+
callbacks=callbacks,
|
|
1367
|
+
dataset_sink_mode=dataset_sink_mode,
|
|
1368
|
+
sink_size=sink_size,
|
|
1369
|
+
initial_epoch=initial_epoch)
|
|
1296
1370
|
|
|
1297
1371
|
# When it's distributed training and using MindRT,
|
|
1298
1372
|
# the node id should be reset to start from 0.
|
|
@@ -1396,7 +1470,7 @@ class Model:
|
|
|
1396
1470
|
|
|
1397
1471
|
Tutorial Examples:
|
|
1398
1472
|
- `Advanced Encapsulation: Model - Train and Save Model
|
|
1399
|
-
<https://www.mindspore.cn/
|
|
1473
|
+
<https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
|
|
1400
1474
|
"""
|
|
1401
1475
|
device_target = context.get_context("device_target")
|
|
1402
1476
|
if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
|
|
@@ -1493,7 +1567,9 @@ class Model:
|
|
|
1493
1567
|
if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
|
|
1494
1568
|
self._train_network.check_names_and_refresh_name()
|
|
1495
1569
|
self._train_network._is_check_and_refresh = True
|
|
1570
|
+
logger.info("Begin to init dataset in model.build() procedure.")
|
|
1496
1571
|
self._init(train_dataset, valid_dataset, sink_size, epoch)
|
|
1572
|
+
logger.info("The model.build() which contains dataset warmup and network compile is success.")
|
|
1497
1573
|
|
|
1498
1574
|
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
|
|
1499
1575
|
"""
|
|
@@ -1663,7 +1739,7 @@ class Model:
|
|
|
1663
1739
|
|
|
1664
1740
|
Tutorial Examples:
|
|
1665
1741
|
- `Advanced Encapsulation: Model - Train and Save Model
|
|
1666
|
-
<https://www.mindspore.cn/
|
|
1742
|
+
<https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
|
|
1667
1743
|
"""
|
|
1668
1744
|
valid_dataset = self._prepare_obf_dataset(valid_dataset)
|
|
1669
1745
|
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
|