mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -15,7 +15,6 @@
|
|
|
15
15
|
"""Dataset help for minddata dataset"""
|
|
16
16
|
from __future__ import absolute_import
|
|
17
17
|
|
|
18
|
-
import os
|
|
19
18
|
import math
|
|
20
19
|
import copy
|
|
21
20
|
|
|
@@ -25,9 +24,11 @@ from mindspore.common._auto_dynamic import is_auto_dynamic, convert_new_shapes
|
|
|
25
24
|
from mindspore.common.dtype import pytype_to_dtype
|
|
26
25
|
from mindspore.common.api import _cell_graph_executor, _is_args_fullmode, ARG_SPECIFIED
|
|
27
26
|
from mindspore.common._utils import is_shape_unknown
|
|
27
|
+
from mindspore.dataset.core import config as dataset_config
|
|
28
28
|
from mindspore.dataset.engine import offload
|
|
29
29
|
from mindspore import context, nn
|
|
30
|
-
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes,
|
|
30
|
+
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
|
|
31
|
+
_construct_tensor_list, enable_data_broadcast
|
|
31
32
|
from mindspore.parallel._utils import _get_device_num, _get_global_rank, _need_to_full, \
|
|
32
33
|
_to_full_shapes, _get_pipeline_stages, _change_symbols_for_parallel, _is_in_auto_parallel_mode, \
|
|
33
34
|
_origin_shapes, _dynamic_shape_for_dataset
|
|
@@ -263,16 +264,14 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
263
264
|
"The dataset has been connected to other network, please check the code.")
|
|
264
265
|
is_dynamic = bool(network.get_inputs())
|
|
265
266
|
queue_name = dataset.__transfer_dataset__.queue_name
|
|
267
|
+
|
|
266
268
|
# In pipeline parallel, some stages have no GetNext, should not get in.
|
|
269
|
+
# Don't enable dynamic shape(multi-subgraph) feature in pp/dataset_broadcast mode,
|
|
270
|
+
# otherwise get_data_info will stuck since some rank do not consume data.
|
|
267
271
|
use_pipeline_parallel = (context.get_auto_parallel_context("pipeline_stages") > 1)
|
|
272
|
+
data_broadcast = enable_data_broadcast()
|
|
268
273
|
|
|
269
|
-
|
|
270
|
-
dynamic_sink1_env = os.getenv("MS_DEV_DYNAMIC_SINK1", None)
|
|
271
|
-
dynamic_sink1 = True
|
|
272
|
-
if dynamic_sink1_env and dynamic_sink1_env.strip() in ['False', 'false']:
|
|
273
|
-
dynamic_sink1 = False
|
|
274
|
-
|
|
275
|
-
if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and dynamic_sink1:
|
|
274
|
+
if _dynamic_sink_scenario(dataset, dataset_iter, is_dynamic) and not use_pipeline_parallel and not data_broadcast:
|
|
276
275
|
dataset_types, dataset_shapes = dataset_helper.get_data_info()
|
|
277
276
|
# Need to do full_batch for shapes which also do in the _DatasetIterMSLoopSink
|
|
278
277
|
if _need_to_full():
|
|
@@ -314,7 +313,7 @@ def connect_network_with_dataset(network, dataset_helper):
|
|
|
314
313
|
aux.__shape_type__ = str(dataset_types) + str(dataset_shapes)
|
|
315
314
|
|
|
316
315
|
if _dynamic_sink_data(dataset, dataset_iter) and _dynamic_sink_exception_scenario(dataset_iter, is_dynamic) and \
|
|
317
|
-
not use_pipeline_parallel and
|
|
316
|
+
not use_pipeline_parallel and not data_broadcast:
|
|
318
317
|
dataset_helper.get_data_info()
|
|
319
318
|
network.add_flags(sink_mode=True)
|
|
320
319
|
return network
|
|
@@ -686,8 +685,9 @@ class _DatasetIterNormal:
|
|
|
686
685
|
self.dataset = dataset
|
|
687
686
|
self.device_num = _get_device_num()
|
|
688
687
|
self.global_rank = _get_global_rank()
|
|
688
|
+
do_copy = dataset_config.get_iterator_mode()["do_copy"]
|
|
689
689
|
self.iter = self.dataset.create_tuple_iterator(
|
|
690
|
-
num_epochs=epoch_num, do_copy=
|
|
690
|
+
num_epochs=epoch_num, do_copy=do_copy)
|
|
691
691
|
|
|
692
692
|
def __iter__(self):
|
|
693
693
|
return self
|
|
@@ -32,11 +32,9 @@ class Precision(EvaluationBase):
|
|
|
32
32
|
.. math::
|
|
33
33
|
\text{precision} = \frac{\text{true_positive}}{\text{true_positive} + \text{false_positive}}
|
|
34
34
|
|
|
35
|
-
Note:
|
|
36
|
-
In the multi-label cases, the elements of :math:`y` and :math:`y_{pred}` must be 0 or 1.
|
|
37
|
-
|
|
38
35
|
Args:
|
|
39
|
-
eval_type (str): ``'classification'`` or ``'multilabel'`` are supported.
|
|
36
|
+
eval_type (str): ``'classification'`` or ``'multilabel'`` are supported. See the update method below
|
|
37
|
+
for what it does. Default: ``'classification'`` .
|
|
40
38
|
|
|
41
39
|
Supported Platforms:
|
|
42
40
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -76,7 +74,8 @@ class Precision(EvaluationBase):
|
|
|
76
74
|
@rearrange_inputs
|
|
77
75
|
def update(self, *inputs):
|
|
78
76
|
"""
|
|
79
|
-
Updates the internal evaluation result with `y_pred` and `y`.
|
|
77
|
+
Updates the internal evaluation result with `y_pred` and `y`. In the multi-label cases, the elements of
|
|
78
|
+
:math:`y` and :math:`y_pred` must be 0 or 1.
|
|
80
79
|
|
|
81
80
|
Args:
|
|
82
81
|
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
|
mindspore/train/mind_ir_pb2.py
CHANGED
|
@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
|
|
|
20
20
|
syntax='proto2',
|
|
21
21
|
serialized_options=None,
|
|
22
22
|
create_key=_descriptor._internal_create_key,
|
|
23
|
-
serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\
|
|
23
|
+
serialized_pb=b'\n\rmind_ir.proto\x12\x07mind_ir\"\xd8\t\n\x0e\x41ttributeProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01\x64\x18\x04 \x01(\x01\x12\t\n\x01s\x18\x05 \x01(\x0c\x12\x1f\n\x01t\x18\x06 \x01(\x0b\x32\x14.mind_ir.TensorProto\x12\x1e\n\x01g\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12\x0e\n\x06\x66loats\x18\x08 \x03(\x02\x12\x0f\n\x07\x64oubles\x18\t \x03(\x01\x12\x0c\n\x04ints\x18\n \x03(\x03\x12\x0f\n\x07strings\x18\x0b \x03(\x0c\x12%\n\x07tensors\x18\x0c \x03(\x0b\x32\x14.mind_ir.TensorProto\x12#\n\x06graphs\x18\r \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x12\n\ndoc_string\x18\x0e \x01(\t\x12\x15\n\rref_attr_name\x18\x0f \x01(\t\x12\x33\n\x04type\x18\x10 \x01(\x0e\x32%.mind_ir.AttributeProto.AttributeType\x12\'\n\x06values\x18\x11 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x36\n\x08seq_info\x18\x12 \x01(\x0b\x32$.mind_ir.AttributeProto.SeqInfoProto\x12&\n\x07\x66unctor\x18\x13 \x01(\x0b\x32\x15.mind_ir.FunctorProto\x12\x35\n\x0cgraph_holder\x18\x14 \x01(\x0b\x32\x1f.mind_ir.ScalarGraphHolderProto\x1aT\n\x0cSeqInfoProto\x12\x12\n\nis_dyn_len\x18\x01 \x01(\x08\x12\x30\n\x0ftuple_elem_item\x18\x02 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xc8\x04\n\rAttributeType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\n\n\x06TENSOR\x10\x11\x12\t\n\x05GRAPH\x10\x12\x12\x0b\n\x07TENSORS\x10\x13\x12\t\n\x05TUPLE\x10\x14\x12\x08\n\x04LIST\x10\x15\x12\x08\n\x04\x44ICT\x10\x16\x12\n\n\x06UMONAD\x10\x17\x12\x0b\n\x07IOMONAD\x10\x18\x12\x08\n\x04NONE\x10\x19\x12\x14\n\x10PRIMITIVECLOSURE\x10\x1a\x12\x14\n\x10\x46UNCGRAPHCLOSURE\x10\x1b\x12\x12\n\x0ePARTIALCLOSURE\x10\x1c\x12\x14\n\x10UNIONFUNCCLOSURE\x10\x1d\x12\x0e\n\nCSR_TENSOR\x10\x1e\x12\x0e\n\nCOO_TENSOR\x10\x1f\x12\x0e\n\nROW_TENSOR\x10 \x12\x0e\n\nCLASS_TYPE\x10!\x12\x0e\n\nNAME_SPACE\x10\"\x12\n\n\x06SYMBOL\x10#\x12\r\n\tTYPE_NULL\x10$\x12\x0e\n\nMAP_TENSOR\x10%\x12\x0b\n\x07\x46UNCTOR\x10&\x12\n\n\x06SCALAR\x10\'\x12\x17\n\x13SCALAR_GRAPH_HOLDER\x10(\"\xae\x01\n\x0c\x46unctorProto\x12/\n\x04type\x18\x01 \x01(\x0e\x32!.mind_ir.FunctorProto.FunctorType\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\x06values\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"6\n\x0b\x46unctorType\x12\x16\n\x12SHAPE_CALC_FUNCTOR\x10\x01\x12\x0f\n\x0b\x41NY_FUNCTOR\x10\x02\"\x98\x01\n\x0eValueInfoProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06tensor\x18\x02 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x03 \x01(\t\x12\x12\n\ndenotation\x18\x04 \x01(\t\x12*\n\tattr_info\x18\x05 \x01(\x0b\x32\x17.mind_ir.AttributeProto\"\xf3\x01\n\tNodeProto\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12*\n\tattribute\x18\x05 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\x0e\n\x06\x64omain\x18\x07 \x01(\t\x12*\n\tnode_attr\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12,\n\x0bprimal_attr\x18\t \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf8\x03\n\nModelProto\x12\x12\n\nir_version\x18\x01 \x01(\t\x12\x15\n\rproducer_name\x18\x02 \x01(\t\x12\x18\n\x10producer_version\x18\x03 \x01(\t\x12\x0e\n\x06\x64omain\x18\x04 \x01(\t\x12\x15\n\rmodel_version\x18\x05 \x01(\t\x12\x12\n\ndoc_string\x18\x06 \x01(\t\x12\"\n\x05graph\x18\x07 \x01(\x0b\x32\x13.mind_ir.GraphProto\x12&\n\tfunctions\x18\x08 \x03(\x0b\x32\x13.mind_ir.GraphProto\x12\x30\n\x0cpreprocessor\x18\t \x01(\x0b\x32\x1a.mind_ir.PreprocessorProto\x12\x15\n\rlittle_endian\x18\n \x01(\x08\x12(\n\x08parallel\x18\x0b \x01(\x0b\x32\x16.mind_ir.ParallelProto\x12+\n\nprimitives\x18\x0c \x03(\x0b\x32\x17.mind_ir.PrimitiveProto\x12\x17\n\x0fmind_ir_version\x18\r \x01(\x03\x12\x34\n\tuser_info\x18\x0e \x03(\x0b\x32!.mind_ir.ModelProto.UserInfoEntry\x1a/\n\rUserInfoEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x11PreprocessorProto\x12&\n\x02op\x18\x01 \x03(\x0b\x32\x1a.mind_ir.PreprocessOpProto\"\x91\x01\n\x11PreprocessOpProto\x12\x15\n\rinput_columns\x18\x01 \x01(\t\x12\x16\n\x0eoutput_columns\x18\x02 \x01(\t\x12\x17\n\x0fproject_columns\x18\x03 \x01(\t\x12\x0f\n\x07op_type\x18\x04 \x01(\t\x12\x12\n\noperations\x18\x05 \x01(\t\x12\x0f\n\x07offload\x18\x06 \x01(\x08\"\xd2\x02\n\nGraphProto\x12 \n\x04node\x18\x01 \x03(\x0b\x32\x12.mind_ir.NodeProto\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\'\n\tparameter\x18\x03 \x03(\x0b\x32\x14.mind_ir.TensorProto\x12\x12\n\ndoc_string\x18\x04 \x01(\t\x12&\n\x05input\x18\x05 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\'\n\x06output\x18\x06 \x03(\x0b\x32\x17.mind_ir.ValueInfoProto\x12\x12\n\nbprop_hash\x18\x07 \x01(\t\x12*\n\tattribute\x18\x08 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x16\n\x0e\x62prop_filepath\x18\t \x01(\t\x12.\n\rmap_parameter\x18\n \x03(\x0b\x32\x17.mind_ir.MapTensorProto\"\xda\x07\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12\x11\n\tdata_type\x18\x02 \x01(\x05\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x05\x12\x13\n\x0bstring_data\x18\x05 \x03(\x0c\x12\x12\n\nint64_data\x18\x06 \x03(\x03\x12\x0c\n\x04name\x18\x07 \x01(\t\x12\x12\n\ndoc_string\x18\x08 \x01(\t\x12\x10\n\x08raw_data\x18\t \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\n \x03(\x01\x12\x13\n\x0buint64_data\x18\x0b \x03(\x04\x12=\n\rexternal_data\x18\x0c \x01(\x0b\x32&.mind_ir.TensorProto.ExternalDataProto\x12\x0f\n\x07ref_key\x18\r \x01(\t\x12\x10\n\x08min_dims\x18\x0e \x03(\x03\x12\x10\n\x08max_dims\x18\x0f \x03(\x03\x12>\n\x10\x63ompression_type\x18\x10 \x01(\x0e\x32$.mind_ir.TensorProto.CompressionType\x12:\n\x0cquant_params\x18\x11 \x03(\x0b\x32$.mind_ir.TensorProto.QuantParamProto\x1a\x45\n\x11\x45xternalDataProto\x12\x10\n\x08location\x18\x01 \x01(\t\x12\x0e\n\x06offset\x18\x02 \x01(\x03\x12\x0e\n\x06length\x18\x03 \x01(\x03\x1aV\n\x0fQuantParamProto\x12\x17\n\x0fquant_algo_name\x18\x01 \x02(\t\x12*\n\tattribute\x18\x02 \x03(\x0b\x32\x17.mind_ir.AttributeProto\"\xf4\x01\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05UINT8\x10\x02\x12\x08\n\x04INT8\x10\x03\x12\n\n\x06UINT16\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\n\n\x06STRING\x10\x08\x12\x08\n\x04\x42OOL\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\n\n\x06\x44OUBLE\x10\x0b\x12\n\n\x06UINT32\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\r\n\tCOMPLEX64\x10\x0e\x12\x0e\n\nCOMPLEX128\x10\x0f\x12\x0c\n\x08\x42\x46LOAT16\x10\x10\x12\x0b\n\x07\x46LOAT64\x10\x11\x12\x0b\n\x07QINT4X2\x10\x12\"u\n\x0f\x43ompressionType\x12\x12\n\x0eNO_COMPRESSION\x10\x00\x12\x0c\n\x08INDEXING\x10\x01\x12\n\n\x06SPARSE\x10\x02\x12\x07\n\x03\x46SE\x10\x03\x12\x0f\n\x0b\x42IT_PACKING\x10\x04\x12\x0b\n\x07\x46SE_INT\x10\x05\x12\r\n\tFSE_INFER\x10\x06\"\xd1\x01\n\x0eMapTensorProto\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\rdefault_value\x18\x02 \x02(\x0b\x32\x17.mind_ir.AttributeProto\x12(\n\nkey_tensor\x18\x03 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12*\n\x0cvalue_tensor\x18\x04 \x02(\x0b\x32\x14.mind_ir.TensorProto\x12+\n\rstatus_tensor\x18\x05 \x02(\x0b\x32\x14.mind_ir.TensorProto\"5\n\rParallelProto\x12$\n\x06layout\x18\x01 \x03(\x0b\x32\x14.mind_ir.LayoutProto\"\xfd\x01\n\x0bLayoutProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1e\n\x16\x64\x65vice_arrangement_int\x18\x02 \x03(\x03\x12\x16\n\x0etensor_map_int\x18\x03 \x03(\x03\x12\x17\n\x0fslice_shape_int\x18\x04 \x03(\x03\x12\x12\n\nfield_size\x18\x05 \x01(\x03\x12\x15\n\runiform_split\x18\x06 \x01(\x08\x12\x17\n\x0fopt_shard_group\x18\x07 \x01(\t\x12\x17\n\x0fpipeline_shared\x18\x08 \x01(\x08\x12\x0f\n\x07is_send\x18\t \x01(\x08\x12\x11\n\tpeer_rank\x18\n \x01(\x03\x12\x0e\n\x06sr_tag\x18\x0b \x01(\x03\"\xda\x01\n\x0ePrimitiveProto\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07op_type\x18\x02 \x01(\t\x12*\n\tattribute\x18\x03 \x03(\x0b\x32\x17.mind_ir.AttributeProto\x12\x15\n\rinstance_name\x18\x04 \x01(\t\x12\x33\n\tprim_type\x18\x05 \x01(\x0e\x32 .mind_ir.PrimitiveProto.PrimType\"1\n\x08PrimType\x12\r\n\tPRIMITIVE\x10\x01\x12\x16\n\x12PRIMITIVE_FUNCTION\x10\x02\"J\n\x0fScalarNodeProto\x12\x16\n\x0escalar_op_type\x18\x01 \x01(\x03\x12\x10\n\x08in_index\x18\x02 \x03(\x03\x12\r\n\x05value\x18\x03 \x03(\x03\"b\n\x16ScalarGraphHolderProto\x12-\n\x0bscalar_node\x18\x01 \x03(\x0b\x32\x18.mind_ir.ScalarNodeProto\x12\x19\n\x11input_shape_index\x18\x02 \x03(\x04*{\n\x07Version\x12\x14\n\x10IR_VERSION_START\x10\x00\x12\x0e\n\nIR_VERSION\x10\x01\x12!\n\x1dIR_VERSION_WITH_PRIM_FUNCTION\x10\x02\x12\'\n#IR_VERSION_WITH_SCALAR_GRAPH_HOLDER\x10\x03'
|
|
24
24
|
)
|
|
25
25
|
|
|
26
26
|
_VERSION = _descriptor.EnumDescriptor(
|
|
@@ -45,11 +45,16 @@ _VERSION = _descriptor.EnumDescriptor(
|
|
|
45
45
|
serialized_options=None,
|
|
46
46
|
type=None,
|
|
47
47
|
create_key=_descriptor._internal_create_key),
|
|
48
|
+
_descriptor.EnumValueDescriptor(
|
|
49
|
+
name='IR_VERSION_WITH_SCALAR_GRAPH_HOLDER', index=3, number=3,
|
|
50
|
+
serialized_options=None,
|
|
51
|
+
type=None,
|
|
52
|
+
create_key=_descriptor._internal_create_key),
|
|
48
53
|
],
|
|
49
54
|
containing_type=None,
|
|
50
55
|
serialized_options=None,
|
|
51
|
-
serialized_start=
|
|
52
|
-
serialized_end=
|
|
56
|
+
serialized_start=4813,
|
|
57
|
+
serialized_end=4936,
|
|
53
58
|
)
|
|
54
59
|
_sym_db.RegisterEnumDescriptor(_VERSION)
|
|
55
60
|
|
|
@@ -57,6 +62,7 @@ Version = enum_type_wrapper.EnumTypeWrapper(_VERSION)
|
|
|
57
62
|
IR_VERSION_START = 0
|
|
58
63
|
IR_VERSION = 1
|
|
59
64
|
IR_VERSION_WITH_PRIM_FUNCTION = 2
|
|
65
|
+
IR_VERSION_WITH_SCALAR_GRAPH_HOLDER = 3
|
|
60
66
|
|
|
61
67
|
|
|
62
68
|
_ATTRIBUTEPROTO_ATTRIBUTETYPE = _descriptor.EnumDescriptor(
|
|
@@ -266,11 +272,16 @@ _ATTRIBUTEPROTO_ATTRIBUTETYPE = _descriptor.EnumDescriptor(
|
|
|
266
272
|
serialized_options=None,
|
|
267
273
|
type=None,
|
|
268
274
|
create_key=_descriptor._internal_create_key),
|
|
275
|
+
_descriptor.EnumValueDescriptor(
|
|
276
|
+
name='SCALAR_GRAPH_HOLDER', index=40, number=40,
|
|
277
|
+
serialized_options=None,
|
|
278
|
+
type=None,
|
|
279
|
+
create_key=_descriptor._internal_create_key),
|
|
269
280
|
],
|
|
270
281
|
containing_type=None,
|
|
271
282
|
serialized_options=None,
|
|
272
|
-
serialized_start=
|
|
273
|
-
serialized_end=
|
|
283
|
+
serialized_start=683,
|
|
284
|
+
serialized_end=1267,
|
|
274
285
|
)
|
|
275
286
|
_sym_db.RegisterEnumDescriptor(_ATTRIBUTEPROTO_ATTRIBUTETYPE)
|
|
276
287
|
|
|
@@ -294,8 +305,8 @@ _FUNCTORPROTO_FUNCTORTYPE = _descriptor.EnumDescriptor(
|
|
|
294
305
|
],
|
|
295
306
|
containing_type=None,
|
|
296
307
|
serialized_options=None,
|
|
297
|
-
serialized_start=
|
|
298
|
-
serialized_end=
|
|
308
|
+
serialized_start=1390,
|
|
309
|
+
serialized_end=1444,
|
|
299
310
|
)
|
|
300
311
|
_sym_db.RegisterEnumDescriptor(_FUNCTORPROTO_FUNCTORTYPE)
|
|
301
312
|
|
|
@@ -404,8 +415,8 @@ _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor(
|
|
|
404
415
|
],
|
|
405
416
|
containing_type=None,
|
|
406
417
|
serialized_options=None,
|
|
407
|
-
serialized_start=
|
|
408
|
-
serialized_end=
|
|
418
|
+
serialized_start=3528,
|
|
419
|
+
serialized_end=3772,
|
|
409
420
|
)
|
|
410
421
|
_sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE)
|
|
411
422
|
|
|
@@ -454,8 +465,8 @@ _TENSORPROTO_COMPRESSIONTYPE = _descriptor.EnumDescriptor(
|
|
|
454
465
|
],
|
|
455
466
|
containing_type=None,
|
|
456
467
|
serialized_options=None,
|
|
457
|
-
serialized_start=
|
|
458
|
-
serialized_end=
|
|
468
|
+
serialized_start=3774,
|
|
469
|
+
serialized_end=3891,
|
|
459
470
|
)
|
|
460
471
|
_sym_db.RegisterEnumDescriptor(_TENSORPROTO_COMPRESSIONTYPE)
|
|
461
472
|
|
|
@@ -479,8 +490,8 @@ _PRIMITIVEPROTO_PRIMTYPE = _descriptor.EnumDescriptor(
|
|
|
479
490
|
],
|
|
480
491
|
containing_type=None,
|
|
481
492
|
serialized_options=None,
|
|
482
|
-
serialized_start=
|
|
483
|
-
serialized_end=
|
|
493
|
+
serialized_start=4586,
|
|
494
|
+
serialized_end=4635,
|
|
484
495
|
)
|
|
485
496
|
_sym_db.RegisterEnumDescriptor(_PRIMITIVEPROTO_PRIMTYPE)
|
|
486
497
|
|
|
@@ -519,8 +530,8 @@ _ATTRIBUTEPROTO_SEQINFOPROTO = _descriptor.Descriptor(
|
|
|
519
530
|
extension_ranges=[],
|
|
520
531
|
oneofs=[
|
|
521
532
|
],
|
|
522
|
-
serialized_start=
|
|
523
|
-
serialized_end=
|
|
533
|
+
serialized_start=596,
|
|
534
|
+
serialized_end=680,
|
|
524
535
|
)
|
|
525
536
|
|
|
526
537
|
_ATTRIBUTEPROTO = _descriptor.Descriptor(
|
|
@@ -664,6 +675,13 @@ _ATTRIBUTEPROTO = _descriptor.Descriptor(
|
|
|
664
675
|
message_type=None, enum_type=None, containing_type=None,
|
|
665
676
|
is_extension=False, extension_scope=None,
|
|
666
677
|
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
678
|
+
_descriptor.FieldDescriptor(
|
|
679
|
+
name='graph_holder', full_name='mind_ir.AttributeProto.graph_holder', index=19,
|
|
680
|
+
number=20, type=11, cpp_type=10, label=1,
|
|
681
|
+
has_default_value=False, default_value=None,
|
|
682
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
683
|
+
is_extension=False, extension_scope=None,
|
|
684
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
667
685
|
],
|
|
668
686
|
extensions=[
|
|
669
687
|
],
|
|
@@ -678,7 +696,7 @@ _ATTRIBUTEPROTO = _descriptor.Descriptor(
|
|
|
678
696
|
oneofs=[
|
|
679
697
|
],
|
|
680
698
|
serialized_start=27,
|
|
681
|
-
serialized_end=
|
|
699
|
+
serialized_end=1267,
|
|
682
700
|
)
|
|
683
701
|
|
|
684
702
|
|
|
@@ -724,8 +742,8 @@ _FUNCTORPROTO = _descriptor.Descriptor(
|
|
|
724
742
|
extension_ranges=[],
|
|
725
743
|
oneofs=[
|
|
726
744
|
],
|
|
727
|
-
serialized_start=
|
|
728
|
-
serialized_end=
|
|
745
|
+
serialized_start=1270,
|
|
746
|
+
serialized_end=1444,
|
|
729
747
|
)
|
|
730
748
|
|
|
731
749
|
|
|
@@ -784,8 +802,8 @@ _VALUEINFOPROTO = _descriptor.Descriptor(
|
|
|
784
802
|
extension_ranges=[],
|
|
785
803
|
oneofs=[
|
|
786
804
|
],
|
|
787
|
-
serialized_start=
|
|
788
|
-
serialized_end=
|
|
805
|
+
serialized_start=1447,
|
|
806
|
+
serialized_end=1599,
|
|
789
807
|
)
|
|
790
808
|
|
|
791
809
|
|
|
@@ -872,8 +890,8 @@ _NODEPROTO = _descriptor.Descriptor(
|
|
|
872
890
|
extension_ranges=[],
|
|
873
891
|
oneofs=[
|
|
874
892
|
],
|
|
875
|
-
serialized_start=
|
|
876
|
-
serialized_end=
|
|
893
|
+
serialized_start=1602,
|
|
894
|
+
serialized_end=1845,
|
|
877
895
|
)
|
|
878
896
|
|
|
879
897
|
|
|
@@ -911,8 +929,8 @@ _MODELPROTO_USERINFOENTRY = _descriptor.Descriptor(
|
|
|
911
929
|
extension_ranges=[],
|
|
912
930
|
oneofs=[
|
|
913
931
|
],
|
|
914
|
-
serialized_start=
|
|
915
|
-
serialized_end=
|
|
932
|
+
serialized_start=2305,
|
|
933
|
+
serialized_end=2352,
|
|
916
934
|
)
|
|
917
935
|
|
|
918
936
|
_MODELPROTO = _descriptor.Descriptor(
|
|
@@ -1033,8 +1051,8 @@ _MODELPROTO = _descriptor.Descriptor(
|
|
|
1033
1051
|
extension_ranges=[],
|
|
1034
1052
|
oneofs=[
|
|
1035
1053
|
],
|
|
1036
|
-
serialized_start=
|
|
1037
|
-
serialized_end=
|
|
1054
|
+
serialized_start=1848,
|
|
1055
|
+
serialized_end=2352,
|
|
1038
1056
|
)
|
|
1039
1057
|
|
|
1040
1058
|
|
|
@@ -1065,8 +1083,8 @@ _PREPROCESSORPROTO = _descriptor.Descriptor(
|
|
|
1065
1083
|
extension_ranges=[],
|
|
1066
1084
|
oneofs=[
|
|
1067
1085
|
],
|
|
1068
|
-
serialized_start=
|
|
1069
|
-
serialized_end=
|
|
1086
|
+
serialized_start=2354,
|
|
1087
|
+
serialized_end=2413,
|
|
1070
1088
|
)
|
|
1071
1089
|
|
|
1072
1090
|
|
|
@@ -1132,8 +1150,8 @@ _PREPROCESSOPPROTO = _descriptor.Descriptor(
|
|
|
1132
1150
|
extension_ranges=[],
|
|
1133
1151
|
oneofs=[
|
|
1134
1152
|
],
|
|
1135
|
-
serialized_start=
|
|
1136
|
-
serialized_end=
|
|
1153
|
+
serialized_start=2416,
|
|
1154
|
+
serialized_end=2561,
|
|
1137
1155
|
)
|
|
1138
1156
|
|
|
1139
1157
|
|
|
@@ -1227,8 +1245,8 @@ _GRAPHPROTO = _descriptor.Descriptor(
|
|
|
1227
1245
|
extension_ranges=[],
|
|
1228
1246
|
oneofs=[
|
|
1229
1247
|
],
|
|
1230
|
-
serialized_start=
|
|
1231
|
-
serialized_end=
|
|
1248
|
+
serialized_start=2564,
|
|
1249
|
+
serialized_end=2902,
|
|
1232
1250
|
)
|
|
1233
1251
|
|
|
1234
1252
|
|
|
@@ -1273,8 +1291,8 @@ _TENSORPROTO_EXTERNALDATAPROTO = _descriptor.Descriptor(
|
|
|
1273
1291
|
extension_ranges=[],
|
|
1274
1292
|
oneofs=[
|
|
1275
1293
|
],
|
|
1276
|
-
serialized_start=
|
|
1277
|
-
serialized_end=
|
|
1294
|
+
serialized_start=3368,
|
|
1295
|
+
serialized_end=3437,
|
|
1278
1296
|
)
|
|
1279
1297
|
|
|
1280
1298
|
_TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
|
|
@@ -1311,8 +1329,8 @@ _TENSORPROTO_QUANTPARAMPROTO = _descriptor.Descriptor(
|
|
|
1311
1329
|
extension_ranges=[],
|
|
1312
1330
|
oneofs=[
|
|
1313
1331
|
],
|
|
1314
|
-
serialized_start=
|
|
1315
|
-
serialized_end=
|
|
1332
|
+
serialized_start=3439,
|
|
1333
|
+
serialized_end=3525,
|
|
1316
1334
|
)
|
|
1317
1335
|
|
|
1318
1336
|
_TENSORPROTO = _descriptor.Descriptor(
|
|
@@ -1456,8 +1474,8 @@ _TENSORPROTO = _descriptor.Descriptor(
|
|
|
1456
1474
|
extension_ranges=[],
|
|
1457
1475
|
oneofs=[
|
|
1458
1476
|
],
|
|
1459
|
-
serialized_start=
|
|
1460
|
-
serialized_end=
|
|
1477
|
+
serialized_start=2905,
|
|
1478
|
+
serialized_end=3891,
|
|
1461
1479
|
)
|
|
1462
1480
|
|
|
1463
1481
|
|
|
@@ -1516,8 +1534,8 @@ _MAPTENSORPROTO = _descriptor.Descriptor(
|
|
|
1516
1534
|
extension_ranges=[],
|
|
1517
1535
|
oneofs=[
|
|
1518
1536
|
],
|
|
1519
|
-
serialized_start=
|
|
1520
|
-
serialized_end=
|
|
1537
|
+
serialized_start=3894,
|
|
1538
|
+
serialized_end=4103,
|
|
1521
1539
|
)
|
|
1522
1540
|
|
|
1523
1541
|
|
|
@@ -1548,8 +1566,8 @@ _PARALLELPROTO = _descriptor.Descriptor(
|
|
|
1548
1566
|
extension_ranges=[],
|
|
1549
1567
|
oneofs=[
|
|
1550
1568
|
],
|
|
1551
|
-
serialized_start=
|
|
1552
|
-
serialized_end=
|
|
1569
|
+
serialized_start=4105,
|
|
1570
|
+
serialized_end=4158,
|
|
1553
1571
|
)
|
|
1554
1572
|
|
|
1555
1573
|
|
|
@@ -1650,8 +1668,8 @@ _LAYOUTPROTO = _descriptor.Descriptor(
|
|
|
1650
1668
|
extension_ranges=[],
|
|
1651
1669
|
oneofs=[
|
|
1652
1670
|
],
|
|
1653
|
-
serialized_start=
|
|
1654
|
-
serialized_end=
|
|
1671
|
+
serialized_start=4161,
|
|
1672
|
+
serialized_end=4414,
|
|
1655
1673
|
)
|
|
1656
1674
|
|
|
1657
1675
|
|
|
@@ -1711,8 +1729,93 @@ _PRIMITIVEPROTO = _descriptor.Descriptor(
|
|
|
1711
1729
|
extension_ranges=[],
|
|
1712
1730
|
oneofs=[
|
|
1713
1731
|
],
|
|
1714
|
-
serialized_start=
|
|
1715
|
-
serialized_end=
|
|
1732
|
+
serialized_start=4417,
|
|
1733
|
+
serialized_end=4635,
|
|
1734
|
+
)
|
|
1735
|
+
|
|
1736
|
+
|
|
1737
|
+
_SCALARNODEPROTO = _descriptor.Descriptor(
|
|
1738
|
+
name='ScalarNodeProto',
|
|
1739
|
+
full_name='mind_ir.ScalarNodeProto',
|
|
1740
|
+
filename=None,
|
|
1741
|
+
file=DESCRIPTOR,
|
|
1742
|
+
containing_type=None,
|
|
1743
|
+
create_key=_descriptor._internal_create_key,
|
|
1744
|
+
fields=[
|
|
1745
|
+
_descriptor.FieldDescriptor(
|
|
1746
|
+
name='scalar_op_type', full_name='mind_ir.ScalarNodeProto.scalar_op_type', index=0,
|
|
1747
|
+
number=1, type=3, cpp_type=2, label=1,
|
|
1748
|
+
has_default_value=False, default_value=0,
|
|
1749
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
1750
|
+
is_extension=False, extension_scope=None,
|
|
1751
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
1752
|
+
_descriptor.FieldDescriptor(
|
|
1753
|
+
name='in_index', full_name='mind_ir.ScalarNodeProto.in_index', index=1,
|
|
1754
|
+
number=2, type=3, cpp_type=2, label=3,
|
|
1755
|
+
has_default_value=False, default_value=[],
|
|
1756
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
1757
|
+
is_extension=False, extension_scope=None,
|
|
1758
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
1759
|
+
_descriptor.FieldDescriptor(
|
|
1760
|
+
name='value', full_name='mind_ir.ScalarNodeProto.value', index=2,
|
|
1761
|
+
number=3, type=3, cpp_type=2, label=3,
|
|
1762
|
+
has_default_value=False, default_value=[],
|
|
1763
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
1764
|
+
is_extension=False, extension_scope=None,
|
|
1765
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
1766
|
+
],
|
|
1767
|
+
extensions=[
|
|
1768
|
+
],
|
|
1769
|
+
nested_types=[],
|
|
1770
|
+
enum_types=[
|
|
1771
|
+
],
|
|
1772
|
+
serialized_options=None,
|
|
1773
|
+
is_extendable=False,
|
|
1774
|
+
syntax='proto2',
|
|
1775
|
+
extension_ranges=[],
|
|
1776
|
+
oneofs=[
|
|
1777
|
+
],
|
|
1778
|
+
serialized_start=4637,
|
|
1779
|
+
serialized_end=4711,
|
|
1780
|
+
)
|
|
1781
|
+
|
|
1782
|
+
|
|
1783
|
+
_SCALARGRAPHHOLDERPROTO = _descriptor.Descriptor(
|
|
1784
|
+
name='ScalarGraphHolderProto',
|
|
1785
|
+
full_name='mind_ir.ScalarGraphHolderProto',
|
|
1786
|
+
filename=None,
|
|
1787
|
+
file=DESCRIPTOR,
|
|
1788
|
+
containing_type=None,
|
|
1789
|
+
create_key=_descriptor._internal_create_key,
|
|
1790
|
+
fields=[
|
|
1791
|
+
_descriptor.FieldDescriptor(
|
|
1792
|
+
name='scalar_node', full_name='mind_ir.ScalarGraphHolderProto.scalar_node', index=0,
|
|
1793
|
+
number=1, type=11, cpp_type=10, label=3,
|
|
1794
|
+
has_default_value=False, default_value=[],
|
|
1795
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
1796
|
+
is_extension=False, extension_scope=None,
|
|
1797
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
1798
|
+
_descriptor.FieldDescriptor(
|
|
1799
|
+
name='input_shape_index', full_name='mind_ir.ScalarGraphHolderProto.input_shape_index', index=1,
|
|
1800
|
+
number=2, type=4, cpp_type=4, label=3,
|
|
1801
|
+
has_default_value=False, default_value=[],
|
|
1802
|
+
message_type=None, enum_type=None, containing_type=None,
|
|
1803
|
+
is_extension=False, extension_scope=None,
|
|
1804
|
+
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
|
|
1805
|
+
],
|
|
1806
|
+
extensions=[
|
|
1807
|
+
],
|
|
1808
|
+
nested_types=[],
|
|
1809
|
+
enum_types=[
|
|
1810
|
+
],
|
|
1811
|
+
serialized_options=None,
|
|
1812
|
+
is_extendable=False,
|
|
1813
|
+
syntax='proto2',
|
|
1814
|
+
extension_ranges=[],
|
|
1815
|
+
oneofs=[
|
|
1816
|
+
],
|
|
1817
|
+
serialized_start=4713,
|
|
1818
|
+
serialized_end=4811,
|
|
1716
1819
|
)
|
|
1717
1820
|
|
|
1718
1821
|
_ATTRIBUTEPROTO_SEQINFOPROTO.fields_by_name['tuple_elem_item'].message_type = _ATTRIBUTEPROTO
|
|
@@ -1725,6 +1828,7 @@ _ATTRIBUTEPROTO.fields_by_name['type'].enum_type = _ATTRIBUTEPROTO_ATTRIBUTETYPE
|
|
|
1725
1828
|
_ATTRIBUTEPROTO.fields_by_name['values'].message_type = _ATTRIBUTEPROTO
|
|
1726
1829
|
_ATTRIBUTEPROTO.fields_by_name['seq_info'].message_type = _ATTRIBUTEPROTO_SEQINFOPROTO
|
|
1727
1830
|
_ATTRIBUTEPROTO.fields_by_name['functor'].message_type = _FUNCTORPROTO
|
|
1831
|
+
_ATTRIBUTEPROTO.fields_by_name['graph_holder'].message_type = _SCALARGRAPHHOLDERPROTO
|
|
1728
1832
|
_ATTRIBUTEPROTO_ATTRIBUTETYPE.containing_type = _ATTRIBUTEPROTO
|
|
1729
1833
|
_FUNCTORPROTO.fields_by_name['type'].enum_type = _FUNCTORPROTO_FUNCTORTYPE
|
|
1730
1834
|
_FUNCTORPROTO.fields_by_name['values'].message_type = _ATTRIBUTEPROTO
|
|
@@ -1764,6 +1868,7 @@ _PARALLELPROTO.fields_by_name['layout'].message_type = _LAYOUTPROTO
|
|
|
1764
1868
|
_PRIMITIVEPROTO.fields_by_name['attribute'].message_type = _ATTRIBUTEPROTO
|
|
1765
1869
|
_PRIMITIVEPROTO.fields_by_name['prim_type'].enum_type = _PRIMITIVEPROTO_PRIMTYPE
|
|
1766
1870
|
_PRIMITIVEPROTO_PRIMTYPE.containing_type = _PRIMITIVEPROTO
|
|
1871
|
+
_SCALARGRAPHHOLDERPROTO.fields_by_name['scalar_node'].message_type = _SCALARNODEPROTO
|
|
1767
1872
|
DESCRIPTOR.message_types_by_name['AttributeProto'] = _ATTRIBUTEPROTO
|
|
1768
1873
|
DESCRIPTOR.message_types_by_name['FunctorProto'] = _FUNCTORPROTO
|
|
1769
1874
|
DESCRIPTOR.message_types_by_name['ValueInfoProto'] = _VALUEINFOPROTO
|
|
@@ -1777,6 +1882,8 @@ DESCRIPTOR.message_types_by_name['MapTensorProto'] = _MAPTENSORPROTO
|
|
|
1777
1882
|
DESCRIPTOR.message_types_by_name['ParallelProto'] = _PARALLELPROTO
|
|
1778
1883
|
DESCRIPTOR.message_types_by_name['LayoutProto'] = _LAYOUTPROTO
|
|
1779
1884
|
DESCRIPTOR.message_types_by_name['PrimitiveProto'] = _PRIMITIVEPROTO
|
|
1885
|
+
DESCRIPTOR.message_types_by_name['ScalarNodeProto'] = _SCALARNODEPROTO
|
|
1886
|
+
DESCRIPTOR.message_types_by_name['ScalarGraphHolderProto'] = _SCALARGRAPHHOLDERPROTO
|
|
1780
1887
|
DESCRIPTOR.enum_types_by_name['Version'] = _VERSION
|
|
1781
1888
|
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
|
|
1782
1889
|
|
|
@@ -1903,6 +2010,20 @@ PrimitiveProto = _reflection.GeneratedProtocolMessageType('PrimitiveProto', (_me
|
|
|
1903
2010
|
})
|
|
1904
2011
|
_sym_db.RegisterMessage(PrimitiveProto)
|
|
1905
2012
|
|
|
2013
|
+
ScalarNodeProto = _reflection.GeneratedProtocolMessageType('ScalarNodeProto', (_message.Message,), {
|
|
2014
|
+
'DESCRIPTOR' : _SCALARNODEPROTO,
|
|
2015
|
+
'__module__' : 'mind_ir_pb2'
|
|
2016
|
+
# @@protoc_insertion_point(class_scope:mind_ir.ScalarNodeProto)
|
|
2017
|
+
})
|
|
2018
|
+
_sym_db.RegisterMessage(ScalarNodeProto)
|
|
2019
|
+
|
|
2020
|
+
ScalarGraphHolderProto = _reflection.GeneratedProtocolMessageType('ScalarGraphHolderProto', (_message.Message,), {
|
|
2021
|
+
'DESCRIPTOR' : _SCALARGRAPHHOLDERPROTO,
|
|
2022
|
+
'__module__' : 'mind_ir_pb2'
|
|
2023
|
+
# @@protoc_insertion_point(class_scope:mind_ir.ScalarGraphHolderProto)
|
|
2024
|
+
})
|
|
2025
|
+
_sym_db.RegisterMessage(ScalarGraphHolderProto)
|
|
2026
|
+
|
|
1906
2027
|
|
|
1907
2028
|
_MODELPROTO_USERINFOENTRY._options = None
|
|
1908
2029
|
# @@protoc_insertion_point(module_scope)
|
mindspore/train/model.py
CHANGED
|
@@ -36,7 +36,7 @@ from mindspore.train.metrics import get_metrics, get_metric_fn
|
|
|
36
36
|
from mindspore._checkparam import check_input_data, check_output_data
|
|
37
37
|
from mindspore import _checkparam as Validator
|
|
38
38
|
from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
|
|
39
|
-
|
|
39
|
+
TFTRegister
|
|
40
40
|
from mindspore.train.callback import __all__ as internal_cb_names
|
|
41
41
|
from mindspore.train.callback._cluster_monitor import ClusterMonitor
|
|
42
42
|
from mindspore import context
|
|
@@ -46,7 +46,7 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
|
|
|
46
46
|
from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
|
|
47
47
|
_cache_enable, _enable_distributed_mindrt
|
|
48
48
|
from mindspore.train.metrics import Loss
|
|
49
|
-
from mindspore.
|
|
49
|
+
from mindspore.log import vlog_print
|
|
50
50
|
from mindspore import nn
|
|
51
51
|
from mindspore.boost import AutoBoost
|
|
52
52
|
from mindspore.context import ParallelMode
|
|
@@ -143,21 +143,22 @@ def _handle_tft(func):
|
|
|
143
143
|
except RuntimeError as e:
|
|
144
144
|
logger.info("uce wrapper caught RuntimeError")
|
|
145
145
|
if not uce_env:
|
|
146
|
-
logger.
|
|
146
|
+
logger.error("uce wrapper caught RuntimeError but uce not enable, enter MindIO TTP process.",
|
|
147
|
+
exc_info=True)
|
|
147
148
|
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
148
149
|
raise e
|
|
149
150
|
e_str = str(e)
|
|
150
151
|
logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
|
|
151
152
|
if "UCEError" in e_str:
|
|
152
|
-
obj.is_uce_rank = True
|
|
153
153
|
logger.info("uce wrapper report UCEError")
|
|
154
|
+
obj.is_uce_rank = True
|
|
154
155
|
tft.tft_report_error(tft.ReportState.RS_UCE.value)
|
|
155
156
|
elif "ForceStopError" in e_str:
|
|
156
157
|
logger.info("uce wrapper caught RuntimeError ForceStopError")
|
|
157
158
|
force_stop_err = tft.ReportState.RS_NORMAL.value
|
|
158
159
|
tft.tft_report_error(force_stop_err)
|
|
159
160
|
else:
|
|
160
|
-
logger.
|
|
161
|
+
logger.error("uce wrapper caught other RuntimeError, enter MindIO TTP process.", exc_info=True)
|
|
161
162
|
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
162
163
|
raise e
|
|
163
164
|
ret = tft.tft_wait_next_action()
|
|
@@ -187,12 +188,14 @@ def _handle_tft(func):
|
|
|
187
188
|
cb_initial_step = initial_step
|
|
188
189
|
|
|
189
190
|
kwargs["initial_step"] = cb_initial_step
|
|
191
|
+
# reset all accu grads to zero
|
|
192
|
+
obj._reset_acc_grads()
|
|
190
193
|
|
|
191
194
|
logger.info("uce wrapper repair complete \
|
|
192
195
|
initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
|
|
193
196
|
continue
|
|
194
197
|
except BaseException as e:
|
|
195
|
-
logger.
|
|
198
|
+
logger.error("uce wrapper caught BaseException error, enter MindIO TTP process.", exc_info=True)
|
|
196
199
|
tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
|
|
197
200
|
raise e
|
|
198
201
|
else:
|
|
@@ -908,10 +911,6 @@ class Model:
|
|
|
908
911
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
|
909
912
|
valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
|
|
910
913
|
cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
|
|
911
|
-
if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
|
|
912
|
-
FlopsUtilizationCollector not in cb_params.list_callback:
|
|
913
|
-
cb_params.list_callback.insert(0, FlopsUtilizationCollector(
|
|
914
|
-
cb_params.batch_num, full_flops=False))
|
|
915
914
|
if context.get_context("mode") == context.PYNATIVE_MODE:
|
|
916
915
|
cb_params.list_callback.insert(0, _StepSync())
|
|
917
916
|
callbacks = cb_params.list_callback
|
|
@@ -1594,7 +1593,7 @@ class Model:
|
|
|
1594
1593
|
|
|
1595
1594
|
def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
|
|
1596
1595
|
"""
|
|
1597
|
-
Evaluation process in
|
|
1596
|
+
Evaluation process in :func:`mindspore.train.Model.fit`.
|
|
1598
1597
|
|
|
1599
1598
|
Args:
|
|
1600
1599
|
valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
|
|
@@ -1670,6 +1669,9 @@ class Model:
|
|
|
1670
1669
|
cb_params.eval_results.update({"eval_loss": eval_loss})
|
|
1671
1670
|
list_callback.on_eval_end(run_context)
|
|
1672
1671
|
|
|
1672
|
+
dataset_helper.stop_send()
|
|
1673
|
+
dataset_helper.release()
|
|
1674
|
+
|
|
1673
1675
|
return metrics
|
|
1674
1676
|
|
|
1675
1677
|
def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
|
|
@@ -1780,10 +1782,6 @@ class Model:
|
|
|
1780
1782
|
cb_params.mode = "eval"
|
|
1781
1783
|
cb_params.cur_step_num = 0
|
|
1782
1784
|
cb_params.list_callback = self._transform_callbacks(callbacks)
|
|
1783
|
-
if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
|
|
1784
|
-
FlopsUtilizationCollector not in cb_params.list_callback:
|
|
1785
|
-
cb_params.list_callback.insert(0, FlopsUtilizationCollector(
|
|
1786
|
-
cb_params.batch_num, full_flops=False))
|
|
1787
1785
|
cb_params.network = self._network
|
|
1788
1786
|
|
|
1789
1787
|
self._clear_metrics()
|