mindspore-2.4.0-cp311-cp311-win_amd64.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mindspore might be problematic.
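For context, a wheel (`.whl`) file is a standard zip archive, so a file manifest like the one below can be reproduced locally. The sketch below is not part of the diff tooling; the wheel path is an assumption, and `+0 -0` entries in the listing correspond to binary files (`.dll`/`.pyd`) whose line counts the diff viewer does not report.

```python
# Minimal sketch: list the members of a downloaded MindSpore wheel
# using only the standard library. A .whl is a zip archive.
import zipfile

WHEEL = "mindspore-2.4.0-cp311-cp311-win_amd64.whl"  # hypothetical local path

with zipfile.ZipFile(WHEEL) as whl:
    for info in whl.infolist():
        # Print each archived file with its uncompressed size; text files
        # (.py) carry the "+N" line counts in the diff, binaries show "+0 -0".
        print(f"{info.filename}  ({info.file_size} bytes)")
```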
- mindspore/.commit_id +1 -0
- mindspore/ConcurrencyCheck.dll +0 -0
- mindspore/CppBuildInsights.dll +0 -0
- mindspore/CppCoreCheck.dll +0 -0
- mindspore/EnumIndex.dll +0 -0
- mindspore/EspXEngine.dll +0 -0
- mindspore/HResultCheck.dll +0 -0
- mindspore/KernelTraceControl.dll +0 -0
- mindspore/LocalESPC.dll +0 -0
- mindspore/Microsoft.Diagnostics.Tracing.EventSource.dll +0 -0
- mindspore/Microsoft.VisualStudio.RemoteControl.dll +0 -0
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Microsoft.VisualStudio.Utilities.Internal.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/System.Runtime.CompilerServices.Unsafe.dll +0 -0
- mindspore/VariantClear.dll +0 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/cfgpersist.dll +0 -0
- mindspore/clang_rt.asan_dbg_dynamic-x86_64.dll +0 -0
- mindspore/clang_rt.asan_dynamic-x86_64.dll +0 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/d3dcompiler_47.dll +0 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/perf_msvcbuildinsights.dll +0 -0
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcomp140.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1406 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
mindspore/train/model.py
ADDED
@@ -0,0 +1,2252 @@
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Model."""
+from __future__ import absolute_import
+
+from collections.abc import Iterable
+from functools import wraps
+
+import sys
+import os
+import math
+import copy
+import importlib
+import time
+import numpy as np
+
+import mindspore
+import mindspore.dataset as ds
+from mindspore import log as logger
+from mindspore.train.serialization import save_checkpoint, load_checkpoint
+from mindspore.train.callback._checkpoint import ModelCheckpoint, _chg_ckpt_file_name_if_same_exist
+from mindspore.common.tensor import Tensor
+from mindspore.train.metrics import get_metrics, get_metric_fn
+from mindspore._checkparam import check_input_data, check_output_data
+from mindspore import _checkparam as Validator
+from mindspore.train.callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback, TimeMonitor,\
+    FlopsUtilizationCollector, TFTRegister
+from mindspore.train.callback import __all__ as internal_cb_names
+from mindspore.train.callback._cluster_monitor import ClusterMonitor
+from mindspore import context
+from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_parameter_broadcast, \
+    _device_number_check, _parameter_broadcast_check, _parallel_predict_check, \
+    _reset_op_id_with_offset
+from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver, _is_ps_mode, \
+    _cache_enable, _enable_distributed_mindrt
+from mindspore.train.metrics import Loss
+from mindspore import nn
+from mindspore.boost import AutoBoost
+from mindspore.context import ParallelMode
+from mindspore.parallel._recovery_context import _set_recovery_context, _get_recovery_context
+from mindspore.train.dataset_helper import DatasetHelper, connect_network_with_dataset
+from mindspore.common.api import _pynative_executor, ARG_SPECIFIED, TOTAL_ARG_LEN
+from mindspore.dataset.core.config import get_debug_mode
+from mindspore.dataset.engine.datasets import _set_training_dataset, _reset_training_dataset
+from mindspore.train import amp
+from mindspore._c_expression import _framework_profiler_step_start, _framework_profiler_step_end
+
+
+def _transfer_tensor_to_tuple(inputs):
+    """
+    If the input is a tensor, convert it to a tuple. If not, the output is unchanged.
+    """
+    if isinstance(inputs, Tensor):
+        return (inputs,)
+
+    return inputs
+
+
+class _StepSync(Callback):
+    @staticmethod
+    def step_end(run_context):
+        _pynative_executor.sync()
+
+
+class _FrameworkProfilerCallback(Callback):
+    """
+    Profiler callback of framework for training.
+    """
+
+    def step_begin(self, run_context):
+        _framework_profiler_step_start()
+
+    def step_end(self, run_context):
+        _framework_profiler_step_end()
+
+
+def _save_final_ckpt(func):
+    """
+    Decorator function, which saves the current checkpoint when an exception occurs during training.
+    """
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        obj = None
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), ModelCheckpoint):
+            obj = kwargs.get('callbacks')
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+            for item in kwargs.get('callbacks'):
+                if isinstance(item, ModelCheckpoint):
+                    obj = item
+        if obj and obj._config and obj._config.exception_save:
+            try:
+                func(self, *args, **kwargs)
+            except BaseException as e:
+                # pylint: disable=W0212
+                prefix = _chg_ckpt_file_name_if_same_exist(obj._directory, obj._exception_prefix, True)
+                cur_ckpoint_file = prefix + "-" + str(self._current_epoch_num) + "_" \
+                    + str(self._current_step_num) + "_breakpoint.ckpt"
+                cur_file = os.path.join(obj._directory, cur_ckpoint_file)
+                if "epoch_num" in obj._append_dict:
+                    obj._append_dict["epoch_num"] = obj._append_epoch_num + self._current_epoch_num
+                if "step_num" in obj._append_dict:
+                    obj._append_dict["step_num"] = obj._append_step_num + self._current_step_num
+                save_checkpoint(self._train_network, cur_file, obj._config.integrated_save, obj._config.async_save,
+                                obj._append_dict, obj._config.enc_key, obj._config.enc_mode)
+                raise e
+        else:
+            func(self, *args, **kwargs)
+    return wrapper
+
+def _handle_tft(func):
+    """
+    Decorator function, which starts uce handle process when an exception occurs during training.
+    """
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        obj = None
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), TFTRegister):
+            obj = kwargs.get('callbacks')
+        if kwargs.get('callbacks') and isinstance(kwargs.get('callbacks'), list):
+            for item in kwargs.get('callbacks'):
+                if isinstance(item, TFTRegister):
+                    obj = item
+        if obj:
+            tft = obj.tft
+            tft_env = os.getenv("MS_ENABLE_TFT", "")
+            uce_env = "UCE:1" in tft_env
+            while True:
+                try:
+                    return func(self, *args, **kwargs)
+                except RuntimeError as e:
+                    logger.info("uce wrapper caught RuntimeError")
+                    if not uce_env:
+                        logger.info("uce wrapper caught RuntimeError uce not enable")
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                        raise e
+                    e_str = str(e)
+                    logger.info("uce wrapper caught RuntimeError e_str:{}".format(e_str))
+                    if "UCEError" in e_str:
+                        logger.info("uce wrapper report UCEError")
+                        tft.tft_report_error(tft.ReportState.RS_UCE.value)
+                    elif "ForceStopError" in e_str:
+                        logger.info("uce wrapper caught RuntimeError ForceStopError")
+                        force_stop_err = tft.ReportState.RS_NORMAL.value
+                        tft.tft_report_error(force_stop_err)
+                    else:
+                        logger.info("uce wrapper caught RuntimeError rankid: {} OTHER ERROR")
+                        tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                        raise e
+                    ret = tft.tft_wait_next_action()
+                    if ret == tft.Action.EXIT.value:
+                        raise e
+                    repair_step = tft.tft_get_repair_step()
+                    logger.info("uce wrapper caught repair finish REPAIR STEP: {} batch_num: \
+                        {}".format(repair_step, self.batch_num))
+                    initial_epoch = int(repair_step/self.batch_num)
+                    initial_step = repair_step % self.batch_num
+                    kwargs["initial_epoch"] = initial_epoch
+
+                    train_dataset = args[1]
+                    dataset_sink_mode = args[3] if len(args) > 3 else kwargs.get('dataset_sink_mode', True)
+                    sink_size = args[4] if len(args) > 4 else kwargs.get('sink_size', -1)
+
+                    cb_initial_step = 0
+                    if dataset_sink_mode:
+                        train_dataset.set_init_step(initial_epoch)
+                        dataset_size = train_dataset.get_dataset_size()
+                        if sink_size != -1:
+                            cb_initial_step = initial_epoch * sink_size + initial_step
+                        else:
+                            cb_initial_step = initial_epoch * dataset_size + initial_step
+                    else:
+                        train_dataset.set_init_step(initial_step)
+                        cb_initial_step = initial_step
+
+                    kwargs["initial_step"] = cb_initial_step
+
+                    logger.info("uce wrapper repair complete \
+                        initial_epoch: {}, cb_initial_step: {} ".format(initial_epoch, cb_initial_step))
+                    continue
+                except BaseException as e:
+                    logger.info("uce wrapper caught BaseException error")
+                    tft.tft_report_error(tft.ReportState.RS_UNKNOWN.value)
+                    raise e
+        else:
+            return func(self, *args, **kwargs)
+    return wrapper
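The retry loop above is the whole fault-tolerance mechanism in miniature: classify the error, report it to the TFT controller, wait for the repair verdict, then recompute the resume point and re-enter training. A minimal sketch of that shape, where `RepairClient` is a hypothetical stand-in for the real `tft` object and its `tft_report_error`/`tft_wait_next_action`/`tft_get_repair_step` API:

class RepairClient:
    """Hypothetical stand-in for the TFT controller object used above."""

    def __init__(self):
        self.failures = 0

    def report_error(self, err):
        self.failures += 1

    def wait_next_action(self):
        # The real controller blocks until repair finishes; give up after 3 tries here.
        return "exit" if self.failures >= 3 else "repair"

    def get_repair_step(self):
        return 42  # illustrative repaired step


def retry_on_repair(func):
    def wrapper(client, *args, **kwargs):
        while True:
            try:
                return func(client, *args, **kwargs)
            except RuntimeError as err:
                client.report_error(err)  # classify and report, as tft_report_error does
                if client.wait_next_action() == "exit":
                    raise  # unrecoverable: propagate, mirroring the EXIT branch above
                kwargs["initial_step"] = client.get_repair_step()
                continue  # re-enter training from the repaired step
    return wrapper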
+
+
+def _check_tft():
+    """Check if TFT is supported"""
+    tft_env = os.getenv("MS_ENABLE_TFT")
+    device_target = context.get_context("device_target")
+    if tft_env and device_target == "Ascend":
+        from mindspore._c_expression import MSContext
+        ascend_target = MSContext.get_instance().get_ascend_soc_version()
+        if ascend_target == 'ascend910':
+            raise ValueError("TFT is not supported when using ascend910")
+        ms_mode = context.get_context("mode")
+        if ms_mode != mindspore.GRAPH_MODE:
+            raise ValueError("TFT is only supported in GRAPH_MODE")
+        jit_level = context.get_context("jit_level")
+        if jit_level == "O2" and "UCE:1" in tft_env:
+            raise ValueError("TFT is not supported when using jit_level == O2")
+
+
+def _append_ccae(callbacks):
+    """Add cluster monitoring when CCAE is enabled."""
+    perf_config = os.getenv("PERF_DUMP_CONFIG")
+    if perf_config is None:
+        return callbacks
+    pairs = perf_config.split(',')
+    perf_config_dict = {}
+    for pair in pairs:
+        key, value = pair.split(':')
+        if value.lower() == 'true':
+            perf_config_dict[key] = True
+        elif value.lower() == 'false':
+            perf_config_dict[key] = False
+        elif value.isdigit():
+            perf_config_dict[key] = int(value)
+        else:
+            perf_config_dict[key] = value
+    if perf_config_dict.get("enable", False):
+        if callbacks is None:
+            callbacks = ClusterMonitor()
+        elif isinstance(callbacks, list):
+            callbacks.append(ClusterMonitor())
+        else:
+            callbacks = [callbacks, ClusterMonitor()]
+    return callbacks
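For reference, the parser above maps a comma-separated `PERF_DUMP_CONFIG` string to a typed dict, and only the `enable` key decides whether a `ClusterMonitor` is appended. A standalone sketch of the same parsing with a hypothetical value (`heartbeat` is an illustrative extra key):

# Hypothetical PERF_DUMP_CONFIG value; only "enable" is consulted by _append_ccae.
raw = "enable:true,heartbeat:30"
parsed = {}
for pair in raw.split(','):
    key, value = pair.split(':')
    if value.lower() in ('true', 'false'):
        parsed[key] = value.lower() == 'true'   # booleans from "true"/"false"
    elif value.isdigit():
        parsed[key] = int(value)                # integers from digit strings
    else:
        parsed[key] = value                     # everything else stays a string
assert parsed == {"enable": True, "heartbeat": 30}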
+
+
+def _get_arg_infos(inputs):
+    """Get compile argument information from inputs.
+
+    Args:
+        inputs (Union[list, tuple, dict]): Argument got from cell which is set by `set_inputs`.
+
+    Raises:
+        RuntimeError: inputs is not a list, tuple or dict.
+        RuntimeError: inputs is a dict without necessary keys and values.
+
+    Returns:
+        tuple: The total argument length and the specified argument information.
+    """
+    if isinstance(inputs, (list, tuple)):
+        arg_specified = [[idx, arg] for idx, arg in enumerate(inputs)]
+        arg_len = len(inputs)
+    elif isinstance(inputs, dict):
+        arg_specified = inputs.get(ARG_SPECIFIED, None)
+        arg_len = inputs.get(TOTAL_ARG_LEN, None)
+        if arg_specified is None or arg_len is None:
+            raise RuntimeError(
+                "The incremental inputs should be processed(with \"%s\" and \"%s\"), but got %s." %
+                (ARG_SPECIFIED, TOTAL_ARG_LEN, str(inputs)))
+    else:
+        raise RuntimeError("inputs should be a list/tuple or a dict, but got %s!" % str(inputs))
+
+    return arg_len, arg_specified
+
+
+def _merge_inputs(inputs1, inputs2):
+    """Merge two processed inputs to a new inputs for latter setting cell's inputs."""
+    is_fullmode1 = isinstance(inputs1, (list, tuple))
+    is_fullmode2 = isinstance(inputs2, (list, tuple))
+
+    if is_fullmode1 and is_fullmode2:
+        return [*inputs1, *inputs2]
+
+    arg_len1, arg_specified1 = _get_arg_infos(inputs1)
+    arg_len2, arg_specified2 = _get_arg_infos(inputs2)
+
+    res_arg_len = arg_len1 + arg_len2
+    res_arg_specified = []
+    res_arg_specified.extend(arg_specified1)
+    # The second inputs should add offset before merging.
+    for idx, arg in arg_specified2:
+        res_arg_specified.append([idx + arg_len1, arg])
+
+    return {ARG_SPECIFIED: res_arg_specified, TOTAL_ARG_LEN: res_arg_len}
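The index offset is the only subtle part of the merge: indices from the second inputs are shifted by the first total length. A worked example, using string placeholders so the sketch runs without tensors (the real `ARG_SPECIFIED`/`TOTAL_ARG_LEN` constants come from mindspore.common.api; the string values here are assumed for a self-contained run):

ARG_SPECIFIED, TOTAL_ARG_LEN = "arg_specified", "total_arg_len"  # assumed key values

inputs1 = {ARG_SPECIFIED: [[0, "t0"]], TOTAL_ARG_LEN: 2}  # 2-arg network, arg 0 specified
inputs2 = {ARG_SPECIFIED: [[1, "t1"]], TOTAL_ARG_LEN: 3}  # 3-arg loss, arg 1 specified
# _merge_inputs shifts the second index by arg_len1 == 2:
merged = {ARG_SPECIFIED: [[0, "t0"], [1 + 2, "t1"]], TOTAL_ARG_LEN: 2 + 3}
assert merged == {ARG_SPECIFIED: [[0, "t0"], [3, "t1"]], TOTAL_ARG_LEN: 5}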
+
+
+def _process_loss_inputs(loss_inputs):
+    """Process loss's inputs whose first input should be dropped for train or eval.
+
+    Args:
+        loss_inputs (Union[list, tuple, dict]): Arguments saved by `set_inputs` or `jit`.
+
+    Raises:
+        RuntimeError: inputs is not a list, tuple or dict.
+        RuntimeError: inputs is a dict without necessary keys and values.
+
+    Returns:
+        list, tuple or dict: Arguments for latter setting.
+    """
+    # For train or eval, the first input of loss is the inner-tensor, so drop it.
+    res = None
+    if isinstance(loss_inputs, (list, tuple)):
+        res = [*loss_inputs]
+        res.pop(0)
+    elif isinstance(loss_inputs, dict):
+        loss_arg_specified = loss_inputs.get(ARG_SPECIFIED, None)
+        loss_arg_len = loss_inputs.get(TOTAL_ARG_LEN, None)
+        if loss_arg_specified is None or loss_arg_len is None:
+            raise RuntimeError(
+                "The loss incremental inputs should be processed(with \"%s\" and \"%s\"), but got %s." %
+                (ARG_SPECIFIED, TOTAL_ARG_LEN, str(loss_inputs)))
+        res_loss_arg_specified = []
+        for idx, arg in loss_arg_specified:
+            if idx == 0:
+                continue
+            res_loss_arg_specified.append([idx, arg])
+        res = {ARG_SPECIFIED: res_loss_arg_specified, TOTAL_ARG_LEN: loss_arg_len - 1}
+    else:
+        raise RuntimeError("loss_inputs should be a list/tuple or a dict, but got %s!" % str(loss_inputs))
+
+    return res
+
+
+def _set_with_processed_inputs(network, inputs):
+    """Save set inputs for computation graph with processed inputs.
+
+    Args:
+        network (nn.Cell): Target cell.
+        inputs (Union[list, tuple, dict]): Inputs argument got from other cell.
+
+    Raises:
+        RuntimeError: network is not a nn.Cell.
+        RuntimeError: inputs is not a list, tuple or dict.
+    """
+    Validator.check_value_type('network', network, nn.Cell)
+    if isinstance(inputs, (list, tuple)):
+        network.set_inputs(*inputs)
+    elif isinstance(inputs, dict):
+        network.set_inputs(**inputs)
+    else:
+        raise RuntimeError(
+            "Reset inputs from a process inputs, should be a list/tuple or a dict, but got %s!" % str(inputs))
+
+
+class Model:
+    """
+    High-Level API for training or inference.
+
+    `Model` groups layers into an object with training and inference features based on the arguments.
+
+    Note:
+        - If mixed precision functions are used, the parameter `optimizer` needs to be set at the same time,
+          otherwise the mixed precision functions do not take effect.
+          When mixed precision functions are used, `global_step` in the optimizer may be different from
+          `cur_step_num` in Model.
+        - After using `custom_mixed_precision` or `auto_mixed_precision` for precision conversion, it is not supported
+          to perform the precision conversion again. If `Model` is used to train a converted network, `amp_level`
+          needs to be configured to ``O0`` to avoid duplicated precision conversion.
+
+    Args:
+        network (Cell): A training or testing network.
+        loss_fn (Cell): Objective function. If `loss_fn` is None, the `network` should contain the calculation of loss.
+            Default: ``None`` .
+        optimizer (Cell): Optimizer for updating the weights. If `optimizer` is None, the `network` needs to
+            do backpropagation and update weights. Default: ``None`` .
+        metrics (Union[dict, set]): A dictionary or a set of metrics for model evaluation,
+            e.g. {'accuracy', 'recall'}. Default: ``None`` .
+        eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
+            `eval_network` . Default: ``None`` .
+        eval_indexes (list): It is used when eval_network is defined. If `eval_indexes` is None by default, all outputs
+            of the `eval_network` would be passed to metrics. If `eval_indexes` is set, it must contain
+            three elements: the positions of loss value, predicted value and label in outputs of the
+            `eval_network`. In this case, the loss value will be passed to the `Loss` metric, the
+            predicted value and label will be passed to other metrics.
+            :func:`mindspore.train.Metric.set_indexes` is recommended instead of `eval_indexes`.
+            Default: ``None`` .
+        amp_level (str): Option for argument `level` in :func:`mindspore.amp.build_train_network`, level for mixed
+            precision training. Supports ["O0", "O1", "O2", "O3", "auto"]. Default: ``"O0"`` .
+
+            For details on `amp_level` , refer to :func:`mindspore.amp.auto_mixed_precision`.
+
+            The BatchNorm strategy can be changed by `keep_batchnorm_fp32` settings in `kwargs`. `keep_batchnorm_fp32`
+            must be a bool. The loss scale strategy can be changed by `loss_scale_manager` setting in `kwargs`.
+            `loss_scale_manager` should be a subclass of :class:`mindspore.amp.LossScaleManager`.
+
+        boost_level (str): Option for argument `level` in `mindspore.boost`, level for boost mode
+            training. Supports ["O0", "O1", "O2"]. Default: ``"O0"`` .
+
+            - "O0": Do not change.
+            - "O1": Enable the boost mode, the performance is improved by about 20%, and
+              the accuracy is the same as the original accuracy.
+            - "O2": Enable the boost mode, the performance is improved by about 30%, and
+              the accuracy is reduced by less than 3%.
+
+            If you want to configure the boost mode yourself, you can set `boost_config_dict` as in `boost.py`.
+            In order for this function to work, you need to set the optimizer, eval_network or metric parameters
+            at the same time.
+
+            Notice: The current optimization enabled by default only applies to some networks, and not all networks
+            can obtain the same benefits. It is recommended to enable this function on
+            the Graph mode + Ascend platform, and for better acceleration, refer to the documentation to configure
+            boost_config_dict.
+
+    Examples:
+        >>> from mindspore import nn
+        >>> from mindspore.train import Model
+        >>>
+        >>> # Define the network structure of LeNet5. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
+        >>> net = LeNet5()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
+        >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
+        >>> model.train_network
+        >>> model.predict_network
+        >>> model.eval_network
+        >>> # Create the dataset taking MNIST as an example. Refer to
+        >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
+        >>> dataset = create_dataset()
+        >>> model.train(2, dataset)
+    """
+
+    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None, eval_indexes=None,
+                 amp_level="O0", boost_level="O0", **kwargs):
+        self._network = network
+        self._loss_fn = loss_fn
+        self._optimizer = optimizer
+        self._loss_scale_manager = None
+        self._loss_scale_manager_set = False
+        self._keep_bn_fp32 = None
+        self._check_kwargs(kwargs)
+        self._amp_level = amp_level
+        self._boost_level = boost_level
+        self._eval_network = eval_network
+        self._process_amp_args(kwargs)
+        self._parallel_mode = _get_parallel_mode()
+        self._device_number = _get_device_num()
+        self._parameter_broadcast = _get_parameter_broadcast()
+        self._metrics = metrics
+
+        self._check_amp_level_arg(optimizer, amp_level)
+        self._check_for_graph_cell(kwargs)
+        self._build_boost_network(kwargs)
+        self._train_network = self._build_train_network()
+        self._train_network._jit_config_dict = network.jit_config_dict
+        self._build_eval_network(metrics, self._eval_network, eval_indexes)
+        self._build_predict_network()
+        self._current_epoch_num = 0
+        self._current_step_num = 0
+        self.epoch_iter = 0
+        self.enable_recovery = False
+        self._backbone_is_train = True
+        self.need_load_ckpt = False
+        self._lite_full_predictor = None
+        self._lite_incremental_predictor = None
+        self._mindspore_lite = None
+        self._lite_infer = True  # if backend lite infer fails, set False
+        self._mindspore_lite_model_group_id = id(self) & 0xFFFF
+        self.batch_num = -1
+
+    def _check_for_graph_cell(self, kwargs):
+        """Check for graph cell"""
+        if not isinstance(self._network, nn.GraphCell):
+            return
+        if self._amp_level != "O0":
+            logger.warning("amp_level will not work when network is a GraphCell.")
+
+        if self._loss_fn is not None or self._optimizer is not None:
+            raise ValueError("For 'Model', 'loss_fn' and 'optimizer' should be None when network is a GraphCell, "
+                             "but got 'loss_fn': {}, 'optimizer': {}.".format(self._loss_fn, self._optimizer))
+        if kwargs:
+            raise ValueError("For 'Model', the '**kwargs' argument should be empty when network is a GraphCell.")
+
+    def _process_amp_args(self, kwargs):
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
+        if 'loss_scale_manager' in kwargs:
+            self._loss_scale_manager = kwargs['loss_scale_manager']
+            self._loss_scale_manager_set = True
+
+    def _check_amp_level_arg(self, optimizer, amp_level):
+        """Check amp level arg"""
+        if optimizer is None and amp_level != "O0":
+            raise ValueError(
+                "Auto mixed precision will not work because 'optimizer' is None. Please set amp_level='O0' "
+                "to disable auto mixed precision or set 'optimizer' not be None to use auto mixed precision.")
+
+    def _check_kwargs(self, kwargs):
+        for arg in kwargs:
+            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32', 'boost_config_dict', 'acc_level']:
+                raise ValueError(f"The argument in 'kwargs' should be 'loss_scale_manager' or "
+                                 f"'keep_batchnorm_fp32' or 'boost_config_dict' or 'acc_level', but got '{arg}'.")
+
+    def _check_reuse_dataset(self, dataset):
+        if not hasattr(dataset, '__model_hash__'):
+            dataset.__model_hash__ = hash(self)
+        if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
+            raise RuntimeError("The dataset object had been used in other model by model.train(...), "
+                               "please create a new dataset.")
+
+    def _build_boost_network(self, kwargs):
+        """Build the boost network."""
+        boost_config_dict = ""
+        if 'boost_config_dict' in kwargs:
+            boost_config_dict = kwargs['boost_config_dict']
+        if 'acc_level' in kwargs:
+            self._boost_level = kwargs['acc_level']
+            logger.warning("Next version acc_level will be removed, please replace with boost_level")
+        processor = AutoBoost(self._boost_level, boost_config_dict)
+        if processor.level not in ["O1", "O2"]:
+            return
+        if self._optimizer is None:
+            logger.warning("In boost mode, the optimizer must be defined.")
+            return
+        if self._eval_network is None and self._metrics is None:
+            logger.warning("In boost mode, the eval_network and metrics cannot be undefined at the same time.")
+            return
+
+        self._network, self._optimizer = processor.network_auto_process_train(self._network, self._optimizer)
+        if self._eval_network is not None:
+            self._eval_network = processor.network_auto_process_eval(self._eval_network)
+
+    def _build_train_network(self):
+        """Build train network"""
+        network = self._network
+        Validator.check_value_type('network', network, nn.Cell)
+        if self._loss_scale_manager is not None and self._optimizer is None:
+            raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
+
+        net_inputs = network.get_inputs()
+        if self._loss_fn:
+            if self._loss_fn.get_inputs() and net_inputs:
+                loss_inputs = _process_loss_inputs(self._loss_fn.get_inputs())
+                net_inputs = _merge_inputs(net_inputs, loss_inputs)
+        if self._optimizer:
+            amp_config = {}
+            if self._loss_scale_manager_set:
+                amp_config['loss_scale_manager'] = self._loss_scale_manager
+            if self._keep_bn_fp32 is not None:
+                amp_config['keep_batchnorm_fp32'] = self._keep_bn_fp32
+            network = amp.build_train_network(network,
+                                              self._optimizer,
+                                              self._loss_fn,
+                                              level=self._amp_level,
+                                              boost_level=self._boost_level,
+                                              **amp_config)
+        elif self._loss_fn:
+            network = nn.WithLossCell(network, self._loss_fn)
+        # If need to check if loss_fn is not None, but optimizer is None
+
+        if net_inputs is not None:
+            _set_with_processed_inputs(network, net_inputs)
+        return network
+
+    def _build_eval_network(self, metrics, eval_network, eval_indexes):
+        """Build the network for evaluation."""
+        self._metric_fns = get_metrics(metrics)
+        if not self._metric_fns:
+            return
+
+        if eval_network is not None:
+            if eval_indexes is not None and not (isinstance(eval_indexes, list) and len(eval_indexes) == 3):
+                raise ValueError("The argument 'eval_indexes' must be a list or None. If 'eval_indexes' is a list, "
+                                 "length of it must be three. But got 'eval_indexes' {}".format(eval_indexes))
+
+            self._eval_network = eval_network
+            self._eval_indexes = eval_indexes
+        else:
+            if self._loss_fn is None:
+                raise ValueError(f"If `metrics` is set, `eval_network` must not be None. Do not set `metrics` if you"
+                                 f" don't want an evaluation.\n"
+                                 f"If evaluation is required, you need to specify `eval_network`, which will be used in"
+                                 f" the framework to evaluate the model.\n"
+                                 f"For the simple scenarios with one data, one label and one logits, `eval_network` is"
+                                 f" optional, and then you can set `eval_network` or `loss_fn`. For the latter case,"
+                                 f" framework will automatically build an evaluation network with `network` and"
+                                 f" `loss_fn`.")
+            net_inputs = self._network.get_inputs()
+            if self._loss_fn.get_inputs() and net_inputs:
+                loss_inputs = _process_loss_inputs(self._loss_fn.get_inputs())
+                net_inputs = _merge_inputs(net_inputs, loss_inputs)
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O2", "O3", "auto"])
+            if net_inputs is not None:
+                _set_with_processed_inputs(self._eval_network, net_inputs)
+            self._eval_indexes = [0, 1, 2]
+
+    def _build_predict_network(self):
+        """Build the network for prediction."""
+        self._predict_network = self._network
+        # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
+
+    def _clear_metrics(self):
+        """Clear metrics local values."""
+        for metric in self._metric_fns.values():
+            metric.clear()
+
+    def _update_metrics(self, outputs):
+        """Update metrics local values."""
+        if isinstance(outputs, Tensor):
+            outputs = (outputs,)
+        if not isinstance(outputs, tuple):
+            raise ValueError(f"The argument 'outputs' should be tuple, but got {type(outputs)}.")
+
+        if self._eval_indexes is not None and len(outputs) < 3:
+            raise ValueError("The length of 'outputs' must be >= 3, but got {}".format(len(outputs)))
+
+        for metric in self._metric_fns.values():
+            if self._eval_indexes is None:
+                metric.update(*outputs)
+            else:
+                if isinstance(metric, Loss):
+                    metric.update(outputs[self._eval_indexes[0]])
+                else:
+                    metric.update(outputs[self._eval_indexes[1]], outputs[self._eval_indexes[2]])
+
+    def _get_metrics(self):
+        """Get metrics local values."""
+        metrics = dict()
+        # There's no need for server to execute eval, just give fake metrics.
+        for key, value in self._metric_fns.items():
+            if not _is_role_pserver():
+                metrics[key] = value.eval()
+            else:
+                metrics[key] = 1
+        return metrics
+
+    def _get_scaling_sens(self):
+        """get the scaling sens"""
+        scaling_sens = 1
+        if self._loss_scale_manager is not None:
+            scaling_sens = self._loss_scale_manager.get_loss_scale()
+        if self._parallel_mode == ParallelMode.DATA_PARALLEL:
+            scaling_sens /= self._device_number
+        return scaling_sens
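One consequence of `_get_scaling_sens` worth making explicit: under DATA_PARALLEL the configured loss scale is divided by the device count, so each device applies only its share. A quick worked example with assumed values:

loss_scale = 1024      # as returned by loss_scale_manager.get_loss_scale()
device_number = 8      # assumed DATA_PARALLEL world size
scaling_sens = loss_scale / device_number
assert scaling_sens == 128.0  # per-device sensitivity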
|
|
643
|
+
|
|
644
|
+
def _exec_preprocess(self, is_train, dataset, dataset_sink_mode, sink_size=-1, epoch_num=1, dataset_helper=None):
|
|
645
|
+
"""Initializes dataset."""
|
|
646
|
+
if is_train:
|
|
647
|
+
network = self._train_network
|
|
648
|
+
phase = 'train'
|
|
649
|
+
else:
|
|
650
|
+
network = self._eval_network
|
|
651
|
+
phase = 'eval'
|
|
652
|
+
|
|
653
|
+
if dataset_sink_mode and not is_train:
|
|
654
|
+
dataset.__loop_size__ = 1
|
|
655
|
+
|
|
656
|
+
if dataset_helper is None:
|
|
657
|
+
logger.info("Begin to create DatasetHelper.")
|
|
658
|
+
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num)
|
|
659
|
+
|
|
660
|
+
if dataset_sink_mode:
|
|
661
|
+
logger.info("Begin to connect network with dataset.")
|
|
662
|
+
network = connect_network_with_dataset(network, dataset_helper)
|
|
663
|
+
|
|
664
|
+
if _get_recovery_context("enable_recovery") and is_train:
|
|
665
|
+
_set_training_dataset(dataset_helper)
|
|
666
|
+
|
|
667
|
+
network.set_train(is_train)
|
|
668
|
+
network.phase = phase
|
|
669
|
+
self._backbone_is_train = is_train
|
|
670
|
+
|
|
671
|
+
return dataset_helper, network
|
|
672
|
+
|
|
673
|
+
def _check_network_mode(self, network, is_train):
|
|
674
|
+
"""
|
|
675
|
+
Change network mode if modes of backbone network and current network are not matching.
|
|
676
|
+
"""
|
|
677
|
+
if self._backbone_is_train != is_train:
|
|
678
|
+
network.set_train(is_train)
|
|
679
|
+
self._backbone_is_train = is_train
|
|
680
|
+
# Mode train and eval are the same net, network will be set_grad in _build_train_network.
|
|
681
|
+
# But if mode just want to do predict or eval, must set network set_grad False
|
|
682
|
+
if not is_train:
|
|
683
|
+
network.set_grad(False)
|
|
684
|
+
return network
|
|
685
|
+
|
|
686
|
+
def _check_need_ckpt(self, callbacks):
|
|
687
|
+
"""Check callback list contain ckpt"""
|
|
688
|
+
need_ckpt = False
|
|
689
|
+
save_ckpt_steps = 1
|
|
690
|
+
last_triggered_step = 0
|
|
691
|
+
for cb in callbacks:
|
|
692
|
+
if isinstance(cb, ModelCheckpoint):
|
|
693
|
+
need_ckpt = True
|
|
694
|
+
cfg_size = cb._get_save_checkpoint_steps
|
|
695
|
+
save_ckpt_steps = save_ckpt_steps if (cfg_size is None or cfg_size >= sys.maxsize) else cfg_size
|
|
696
|
+
last_triggered_step = cb._get_last_trigger_step
|
|
697
|
+
break
|
|
698
|
+
return need_ckpt, save_ckpt_steps, last_triggered_step
|
|
699
|
+
|
|
700
|
+
def _store_training_step_info(self, cb_params):
|
|
701
|
+
"""
|
|
702
|
+
cache train step info
|
|
703
|
+
:param cb_params: callback params
|
|
704
|
+
:return: none
|
|
705
|
+
"""
|
|
706
|
+
if os.environ.get("MS_ENABLE_CKPT_D2H_ASYNC") != "1":
|
|
707
|
+
return
|
|
708
|
+
if (context.get_context("mode") == context.GRAPH_MODE) and (context.get_context("device_target") == "Ascend"):
|
|
709
|
+
cb_params.need_ckpt, cb_params.save_checkpoint_steps, \
|
|
710
|
+
cb_params.last_triggered_step = self._check_need_ckpt(cb_params.list_callback)
|
|
711
|
+
logger.info(f"need_ckpt:{cb_params.need_ckpt},"
|
|
712
|
+
f"save_checkpoint_steps:{cb_params.save_checkpoint_steps},"
|
|
713
|
+
f"cur_step_num:{cb_params.cur_step_num},"
|
|
714
|
+
f"last_triggered_step:{cb_params.last_triggered_step}")
|
|
715
|
+
context.set_context(ascend_config={"need_ckpt": cb_params.need_ckpt,
|
|
716
|
+
"save_checkpoint_steps": cb_params.save_checkpoint_steps,
|
|
717
|
+
"cur_step_num": cb_params.cur_step_num,
|
|
718
|
+
"last_triggered_step": cb_params.last_triggered_step})
|
|
719
|
+
|
|
720
|
+
def _warmup_dataset(self, epoch, train_dataset, sink_size=-1):
|
|
721
|
+
"""
|
|
722
|
+
Trigger dataset pipeline running before graph compiling.
|
|
723
|
+
|
|
724
|
+
Args:
|
|
725
|
+
epoch (int): Total number of iterations on the data.
|
|
726
|
+
train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
|
|
727
|
+
initialized. Default: ``None``.
|
|
728
|
+
sink_size (int): Control the amount of data in each sink. Default: -1.
|
|
729
|
+
"""
|
|
730
|
+
if sink_size == -1:
|
|
731
|
+
epoch_num = epoch
|
|
732
|
+
else:
|
|
733
|
+
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())
|
|
734
|
+
train_dataset.__total_batch__ = epoch * sink_size
|
|
735
|
+
dataset_helper = None
|
|
736
|
+
dataset_helper, _ = self._exec_preprocess(is_train=True,
|
|
737
|
+
dataset=train_dataset,
|
|
738
|
+
dataset_sink_mode=True,
|
|
739
|
+
sink_size=sink_size,
|
|
740
|
+
epoch_num=epoch_num,
|
|
741
|
+
dataset_helper=dataset_helper)
|
|
742
|
+
train_dataset._dataset_helper = dataset_helper
|
|
743
|
+
train_dataset._warmup_epoch = epoch
|
|
744
|
+
|
|
745
|
+
def _waiting_for_dataset_warmup_ready(self, train_dataset):
|
|
746
|
+
"""
|
|
747
|
+
Wait for the dataset to warmup until there is a batch of data available for training on the device side.
|
|
748
|
+
|
|
749
|
+
Args:
|
|
750
|
+
train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs will be
|
|
751
|
+
initialized. Default: ``None``.
|
|
752
|
+
"""
|
|
753
|
+
mbuf_size = train_dataset.__transfer_dataset__.get_mbuf_queue_size()
|
|
754
|
+
while mbuf_size == 0:
|
|
755
|
+
time.sleep(10)
|
|
756
|
+
mbuf_size = train_dataset.__transfer_dataset__.get_mbuf_queue_size()
|
|
757
|
+
if mbuf_size != 0:
|
|
758
|
+
break
|
|
759
|
+
logger.warning(f"Waiting for the dataset warmup, current device queue size: {mbuf_size}")
|
|
760
|
+
|
|
761
|
    def _init(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
        """
        Initialize compute graphs and data graphs with the sink mode.

        Note:
            Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.

        Args:
            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs
                will be initialized. Default: ``None``.
            valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
                will be initialized, and `metrics` in `Model` can not be None. Default: ``None``.
            sink_size (int): Control the amount of data in each sink. Default: -1.
            epoch (int): Total number of iterations on the data. Default: 1.
        """
        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')

        if not train_dataset and not valid_dataset:
            raise ValueError("The arguments 'train_dataset' and 'valid_dataset' can not both be None or empty.")

        logger.info("Begin to check device number in model.build() procedure.")
        _device_number_check(self._parallel_mode, self._device_number)

        if train_dataset:
            if not isinstance(train_dataset, mindspore.dataset.Dataset):
                raise TypeError("The type of 'train_dataset' must be `Dataset`, "
                                "but got {}.".format(type(train_dataset)))

            logger.info("Begin to check parameter broadcast in model.build() procedure.")
            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
            if self._parameter_broadcast:
                self._train_network.set_broadcast_flag()

            logger.info("Begin to exec preprocess in model.build() procedure.")
            train_dataset.__no_send__ = True
            train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                        dataset=train_dataset,
                                                                        dataset_sink_mode=True,
                                                                        sink_size=sink_size)
            logger.info("Begin to warmup dataset in model.build() procedure.")
            self._warmup_dataset(epoch, train_dataset, sink_size)

            # Since the dataset pipeline has been triggered, delete the flag.
            delattr(train_dataset, "__no_send__")

            # Wait until the dataset warmup is ready.
            logger.info("Begin waiting for dataset warmup in model.build() procedure.")
            self._waiting_for_dataset_warmup_ready(train_dataset)
            logger.info("The dataset warmup was successful in model.build() procedure.")

            if context.get_auto_parallel_context("pipeline_stages") > 1 and valid_dataset:
                train_network.add_flags_recursive(is_first_iteration=True)
            for inputs in train_dataset_helper:
                logger.info("Begin to compile train network in model.build() procedure.")
                train_network.compile(*inputs)
                self._train_network.parameter_layout_dict = train_network.parameter_layout_dict
                break

        if valid_dataset:
            if not isinstance(valid_dataset, mindspore.dataset.Dataset):
                raise TypeError("The type of 'valid_dataset' must be `Dataset`, "
                                "but got {}.".format(type(valid_dataset)))
            if not self._metric_fns:
                raise RuntimeError("If `valid_dataset` is defined, the metric fn can not be None or empty, "
                                   "you should set the argument 'metrics' for model.")

            valid_dataset.__no_send__ = True
            valid_dataset_helper, eval_network = self._exec_preprocess(is_train=False,
                                                                       dataset=valid_dataset,
                                                                       dataset_sink_mode=True)
            if context.get_auto_parallel_context("pipeline_stages") > 1:
                eval_network.add_flags_recursive(is_first_iteration=False)
            for inputs in valid_dataset_helper:
                logger.info("Begin to compile eval network in model.build() procedure.")
                eval_network.compile(*inputs)
                break

    @staticmethod
    def _transform_callbacks(callbacks):
        """Transform callback to a list."""
        if callbacks is None:
            return []

        if isinstance(callbacks, Iterable):
            return list(callbacks)

        return [callbacks]
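
    # Behavior sketch for _transform_callbacks (hypothetical callback objects):
    # None -> [], a single callback cb -> [cb], and any iterable of callbacks
    # -> list(iterable), so downstream code can always treat callbacks as a list.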

    @_handle_tft
    @_save_final_ckpt
    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True, sink_size=-1, initial_epoch=0,
               valid_dataset=None, valid_frequency=1, valid_dataset_sink_mode=True, initial_step=0):
        """
        Training.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) will be
                returned and passed to the network. Otherwise, a tuple (data, label) will
                be returned. The data and label would be passed to the network and loss
                function respectively.
            callbacks (list): List of callback objects which should be executed while training. Default: ``None``.
            dataset_sink_mode (bool): Determine whether the data should be passed through the dataset channel.
                Default: ``True``.
                When configured to pynative mode or CPU, the training process will be performed with
                dataset not sink.
            sink_size (int): Control the amount of data in each sink. Default: -1.
            initial_epoch (int): Epoch at which to start train, it is used for resuming a previous training run.
                Default: 0.
        """
        if self._parameter_broadcast:
            self._train_network.set_broadcast_flag()

        cb_params = _InternalCallbackParam()
        cb_params.cur_step_num = initial_step
        cb_params.train_network = self._train_network
        cb_params.epoch_num = epoch - initial_epoch
        if dataset_sink_mode and sink_size > 0:
            cb_params.batch_num = sink_size
        else:
            cb_params.batch_num = train_dataset.get_dataset_size()
        self.batch_num = cb_params.batch_num
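        # Note on the bookkeeping above: in sink mode with a positive sink_size,
        # batch_num counts the steps of one sink pass rather than the dataset size.
        # Hypothetical example: with sink_size=100, each "epoch" seen by callbacks
        # spans 100 global steps regardless of how large the dataset is.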
        cb_params.mode = "train"
        cb_params.loss_fn = self._loss_fn
        cb_params.optimizer = self._optimizer
        cb_params.parallel_mode = self._parallel_mode
        cb_params.device_number = self._device_number
        cb_params.train_dataset = train_dataset
        cb_params.list_callback = self._transform_callbacks(callbacks)
        valid_infos = (valid_dataset, valid_frequency, valid_dataset_sink_mode)
        cb_params.list_callback.insert(0, _FrameworkProfilerCallback())
        if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
                FlopsUtilizationCollector not in cb_params.list_callback:
            cb_params.list_callback.insert(0, FlopsUtilizationCollector(
                cb_params.batch_num, full_flops=False))
        if context.get_context("mode") == context.PYNATIVE_MODE:
            cb_params.list_callback.insert(0, _StepSync())
        callbacks = cb_params.list_callback
        cb_params.train_dataset_element = None
        cb_params.network = self._network
        # Embedding cache server only runs one step.
        if _is_role_pserver() and _cache_enable():
            epoch = 1
        cb_params.last_save_ckpt_step = None
        cb_params.latest_ckpt_file = None

        # build callback list
        with _CallbackManager(callbacks) as list_callback:
            self._check_reuse_dataset(train_dataset)
            if not dataset_sink_mode:
                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
                                    valid_infos)
            elif context.get_context("device_target") == "CPU":
                logger.info("The CPU cannot support dataset sink mode currently. "
                            "So the training process will be performed with dataset not sink.")
                self._train_process(epoch, train_dataset, list_callback, cb_params, initial_epoch,
                                    valid_infos)
            else:
                self._train_dataset_sink_process(epoch, train_dataset, list_callback,
                                                 cb_params, sink_size, initial_epoch, valid_infos)

    @staticmethod
    def _should_eval(epoch, validation_freq):
        return epoch % validation_freq == 0 if isinstance(validation_freq, int) else epoch in validation_freq
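
    # Illustration of _should_eval (hypothetical values): with an int frequency,
    # _should_eval(4, 2) is True because 4 % 2 == 0; with a list, _should_eval(3, [1, 5])
    # is False because epoch 3 is not listed.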

    def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None,
                                    sink_size=-1, initial_epoch=0, valid_infos=None):
        """
        Training process. The data would be passed to network through dataset channel.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: ``None``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
            sink_size (int): Control the amount of data in each sink. Default: -1.
            initial_epoch (int): Epoch at which to start train, it is used for resuming a previous training run.
                Default: 0.
        """
        is_graph = (context.get_context("mode") == context.GRAPH_MODE)
        dataset_size = train_dataset.get_dataset_size()
        if dataset_size % sink_size != 0:
            logger.info("In dataset_sink mode (dataset_size % sink_size) should equal 0, "
                        "it is suggested to pad/drop data or adjust sink_size. "
                        "But got 'dataset_size': {}, 'sink_size': {}.".format(dataset_size, sink_size))
        if sink_size == -1:
            dataset_sink_num = epoch
        else:
            dataset_sink_num = math.ceil(epoch * sink_size / dataset_size)
        train_dataset.__total_batch__ = epoch * sink_size

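        # Sketch of the sink-count arithmetic above with hypothetical numbers:
        # for epoch=3, sink_size=100, dataset_size=250, dataset_sink_num is
        # math.ceil(3 * 100 / 250) = 2, i.e. 300 total batches are delivered in
        # two passes over the dataset pipeline.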
        cb_params.sink_size = sink_size
        cb_params.dataset_sink_mode = True
        run_context = RunContext(cb_params)
        list_callback.on_train_begin(run_context)
        # used to stop training for early stop, such as stopAtTime or stopAtStep
        dataset_helper = None
        if hasattr(train_dataset, '_dataset_helper'):
            dataset_helper = train_dataset._dataset_helper

        self.epoch_iter = 0
        self._check_enable_recovery()
        # Used to check whether recovery needs to be performed for a restarted process.
        self._check_need_load_ckpt(cb_params, dataset_size, sink_size)
        # Check whether this process is an embedding cache server.
        is_embedding_cache_server = _is_role_pserver() and _cache_enable()

        while self.epoch_iter < (epoch - initial_epoch):
            cb_params.cur_epoch_num = self.epoch_iter + 1 + initial_epoch
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0
            list_callback.on_train_epoch_begin(run_context)
            dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                  dataset=train_dataset,
                                                                  dataset_sink_mode=True,
                                                                  sink_size=sink_size,
                                                                  epoch_num=dataset_sink_num,
                                                                  dataset_helper=dataset_helper)

            cb_params.train_network = train_network
            cb_params.dataset_helper = dataset_helper

            # Perform recovery for a process which is restarted.
            self._reset_training_step_for_abnormal_process(cb_params, dataset_helper)
            # Perform recovery for a process which is not restarted.
            self._reset_training_step_for_normal_process(cb_params, dataset_helper)

            # With data sinking, dataset_helper only iterates once; otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                if is_graph:
                    cb_params.cur_step_num += dataset_helper.sink_size()
                else:
                    cb_params.cur_step_num += 1
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
                self._store_training_step_info(cb_params)
                cb_params.train_dataset_element = inputs
                list_callback.on_train_step_begin(run_context)
                train_network = self._check_network_mode(train_network, True)
                outputs = train_network(*inputs)
                cb_params.net_outputs = outputs

                # In disaster recovery scenarios, callbacks need not be executed if this step failed.
                need_exec_callback_step_end = not (self.enable_recovery and _get_recovery_context("need_reset"))
                if need_exec_callback_step_end:
                    list_callback.on_train_step_end(run_context)

                # Embedding cache server only runs one step.
                if is_embedding_cache_server:
                    break

            dataset_helper.continue_send()

            # When it's distributed training and using MindRT,
            # the node id should be reset to start from 0.
            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case (or 'fit').
            if _enable_distributed_mindrt():
                _reset_op_id_with_offset()

            self._eval_during_train(valid_infos, cb_params, list_callback)

            # In disaster recovery scenarios, callbacks need not be executed if this epoch failed.
            # An embedding cache server need not do the epoch end callback; this process only runs one step.
            need_exec_callback_epoch_end = not ((self.enable_recovery and _get_recovery_context("need_reset"))
                                                or is_embedding_cache_server)

            if need_exec_callback_epoch_end:
                list_callback.on_train_epoch_end(run_context)
            if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics", None)
                cb_params.pop("eval_results", None)

            should_stop = run_context.get_stop_requested()
            if should_stop:
                break

            need_reset_to_beginning = self.enable_recovery and _get_recovery_context("need_reset") \
                and not _get_recovery_context("latest_ckpt_file")
            self.epoch_iter += 1
            if need_reset_to_beginning:
                self.epoch_iter = 0
                cb_params.cur_step_num = 0

        dataset_helper.stop_send()
        dataset_helper.release()

        list_callback.on_train_end(run_context)

    def _eval_during_train(self, valid_infos, cb_params, list_callback):
        """Exec eval during the train process."""
        valid_dataset, valid_frequency, valid_dataset_sink_mode = valid_infos
        if valid_dataset and self._should_eval(cb_params.cur_epoch_num, valid_frequency):
            train_cur_step_num = cb_params.cur_step_num
            train_batch_num = cb_params.batch_num
            train_dataset_sink_mode = cb_params.dataset_sink_mode
            train_net_outputs = cb_params.net_outputs

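            # Keep user callbacks as-is; of the framework-internal callbacks, only
            # TimeMonitor is carried over into the evaluation pass below.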
            eval_callback = []
            for cb in list_callback._callbacks:
                if cb.__class__.__name__ in internal_cb_names:
                    if isinstance(cb, TimeMonitor):
                        eval_callback.append(cb)
                else:
                    eval_callback.append(cb)

            self._eval_in_fit(valid_dataset,
                              callbacks=eval_callback,
                              dataset_sink_mode=valid_dataset_sink_mode,
                              cb_params=cb_params)
            cb_params.mode = "train"
            cb_params.cur_step_num = train_cur_step_num
            cb_params.batch_num = train_batch_num
            cb_params.dataset_sink_mode = train_dataset_sink_mode
            cb_params.net_outputs = train_net_outputs

    def _check_enable_recovery(self):
        """
        Check whether recovery is enabled and whether the execution mode is consistent with it.
        """

        enable_recovery = _get_recovery_context("enable_recovery")
        if not enable_recovery:
            self.enable_recovery = False
        else:
            if context.get_context("mode") != context.GRAPH_MODE:
                raise RuntimeError("Recovery for training only supports graph mode currently.")
            self.enable_recovery = enable_recovery and _is_role_worker()

    def _check_need_load_ckpt(self, cb_params, dataset_size, sink_size=-1):
        """
        Check whether a checkpoint needs to be loaded after an abnormal process restart.

        Args:
            cb_params (_InternalCallbackParam): Callback parameters.
            dataset_size (int): The number of batches in a dataset.
            sink_size (int): Control the amount of data in each sink. Default: -1.
        """
        if not self.enable_recovery:
            self.need_load_ckpt = False

        cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
        if cb_params.latest_ckpt_file:
            recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
            recovery_step_num = _get_recovery_context("latest_ckpt_step")
            dataset_sink_size = sink_size if sink_size > 0 else dataset_size
            cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_sink_size + recovery_step_num
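            # E.g. (hypothetical numbers): recovering from a checkpoint taken at
            # epoch 3, step 40 with a sink size of 100 resumes from global step
            # (3 - 1) * 100 + 40 = 240.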
            cb_params.last_save_ckpt_step = cb_params.cur_step_num
            self.epoch_iter = recovery_epoch_num
            self.need_load_ckpt = True
        else:
            self.need_load_ckpt = False

    def _reset_training_step_for_abnormal_process(self, cb_params, dataset_helper):
        """
        Execute recovery for an abnormally exited process when it restarts.

        Args:
            cb_params (_InternalCallbackParam): Callback parameters.
        """

        if self.need_load_ckpt:
            try:
                load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
            except BaseException as e:
                os.remove(cb_params.latest_ckpt_file)
                raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "
                                   + cb_params.latest_ckpt_file) from e
            _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
            self.need_load_ckpt = False

    def _reset_training_step_for_normal_process(self, cb_params, dataset_helper):
        """
        Execute recovery for a normal process when some other process exits abnormally.

        Args:
            cb_params (_InternalCallbackParam): Callback parameters.
            dataset_helper (DatasetHelper): A class to process the MindData dataset,
                it provides the type, shape and queue name of the dataset to wrap the `GetNext`.
        """

        if self.enable_recovery and _get_recovery_context("need_reset"):
            cb_params.latest_ckpt_file = _get_recovery_context("latest_ckpt_file")
            if cb_params.latest_ckpt_file:
                try:
                    load_checkpoint(cb_params.latest_ckpt_file, cb_params.train_network)
                except BaseException as e:
                    os.remove(cb_params.latest_ckpt_file)
                    raise RuntimeError(e.__str__() + ", load ckpt failed and remove the ckpt: "
                                       + cb_params.latest_ckpt_file) from e

                recovery_epoch_num = _get_recovery_context("latest_ckpt_epoch")
                recovery_step_num = _get_recovery_context("latest_ckpt_step")
                cb_params.cur_step_num = (recovery_epoch_num - 1) * dataset_helper.sink_size() + recovery_step_num
                self.epoch_iter = recovery_epoch_num
                cb_params.cur_epoch_num = self.epoch_iter + 1
                cb_params.last_save_ckpt_step = cb_params.cur_step_num
                _reset_training_dataset(cb_params.cur_step_num, dataset_helper.iter.dataset.get_dataset_size())
            else:
                _reset_training_dataset(0, dataset_helper.iter.dataset.get_dataset_size())

            _set_recovery_context(need_reset=False)

    def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None, initial_epoch=0,
                       valid_infos=None):
        """
        Training process. The data would be passed to network directly.

        Args:
            epoch (int): Total number of iterations on the data.
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            list_callback (Callback): Executor of callback list. Default: ``None``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
            initial_epoch (int): Epoch at which to start train, it is used for resuming a previous training run.
                Default: 0.
        """
        dataset_helper, _ = self._exec_preprocess(is_train=True,
                                                  dataset=train_dataset,
                                                  dataset_sink_mode=False,
                                                  epoch_num=epoch)
        cb_params.dataset_sink_mode = False
        run_context = RunContext(cb_params)
        list_callback.on_train_begin(run_context)
        is_embedding_cache_server = _is_role_pserver() and _cache_enable()

        for i in range(initial_epoch, epoch):
            cb_params.cur_epoch_num = i + 1
            self._current_epoch_num = cb_params.cur_epoch_num
            self._current_step_num = 0

            list_callback.on_train_epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
                next_element = _transfer_tensor_to_tuple(next_element)
                if self._loss_fn and len_element != 2:
                    raise ValueError("When 'loss_fn' is not None, 'train_dataset' should return "
                                     "two elements, but got {}, please check the number of elements "
                                     "returned by 'train_dataset'".format(len_element))
                cb_params.cur_step_num += 1
                self._current_step_num = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
                cb_params.train_dataset_element = next_element
                list_callback.on_train_step_begin(run_context)
                self._check_network_mode(self._train_network, True)
                outputs = self._train_network(*next_element)
                cb_params.net_outputs = outputs
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
                    overflow = outputs[1]
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

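                # Sketch of the drop-overflow contract above (the output layout is an
                # assumption based on the indexing here): the train network returns a
                # tuple whose second element is an overflow flag; it is reduced with
                # np.all and fed back to update the loss scale.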
                list_callback.on_train_step_end(run_context)
                # Embedding cache server only runs one step.
                if is_embedding_cache_server:
                    break
                should_stop = run_context.get_stop_requested()
                if should_stop:
                    break

            # When it's distributed training and using MindRT,
            # the node id should be reset to start from 0.
            # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case (or 'fit').
            if _enable_distributed_mindrt():
                _reset_op_id_with_offset()

            self._eval_during_train(valid_infos, cb_params, list_callback)

            train_dataset.reset()

            # If a parameter is cache enabled, flush data from cache to host before epoch end.
            self._flush_from_cache(cb_params)

            # An embedding cache server need not do the epoch end callback; this process only runs one step.
            if not is_embedding_cache_server:
                list_callback.on_train_epoch_end(run_context)
            if "metrics" in cb_params or "eval_results" in cb_params:
                cb_params.pop("metrics", None)
                cb_params.pop("eval_results", None)
            should_stop = run_context.get_stop_requested()
            if should_stop:
                break

        list_callback.on_train_end(run_context)

    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
        """
        Training API.

        When setting pynative mode or CPU, the training process will be performed with dataset not sink.

        Note:
            If dataset_sink_mode is True, data will be sent to device. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

            When dataset_sink_mode is True, the `step_end` method of the instance of Callback will be called at the end
            of step in PyNative mode, or will be called at the end of epoch in Graph mode.

            If dataset_sink_mode is True, the dataset will be bound to this model and cannot be used by other models.

            If sink_size > 0, each epoch of the dataset can be traversed an unlimited number of times until you get
            sink_size elements of the dataset. The next epoch continues to traverse from the end position of the
            previous traversal.

            The interface builds the computational graphs and then executes the computational graphs. However, when
            `Model.build` is executed first, it only performs the graphs execution.

        Args:
            epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch.
                         If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will
                         train `sink_size` steps instead of total steps of dataset.
                         If `epoch` is used with `initial_epoch`, it is to be understood as "final epoch".
            train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be
                                     passed to the `network` and the `loss_fn` respectively, so a tuple (data, label)
                                     should be returned from dataset. If there are multiple data items or labels,
                                     set `loss_fn` to None and implement calculation of loss in `network`,
                                     then a tuple (data1, data2, data3, ...) with all data returned from dataset will
                                     be passed to the `network`.
            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
                                                            which should be executed while training.
                                                            Default: ``None``.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                                      Configure pynative mode or CPU, the training process will be performed with
                                      dataset not sink. Default: ``False``.
            sink_size (int): Control the number of steps for each sinking.
                             `sink_size` is invalid if `dataset_sink_mode` is False.
                             If sink_size = -1, sink the complete dataset for each epoch.
                             If sink_size > 0, sink sink_size data for each epoch.
                             Default: -1.
            initial_epoch (int): Epoch at which to start train, it is used for resuming a previous training run.
                                 Default: 0.

        Examples:
            >>> import mindspore as ms
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> dataset = create_dataset()
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> loss_scale_manager = ms.FixedLossScaleManager(1024., False)
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
            ...               loss_scale_manager=loss_scale_manager)
            >>> model.train(2, dataset)
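            >>> # A hedged sketch of resuming a run (epoch counts are illustrative):
            >>> # train up to epoch 4, starting from a previously reached epoch 2.
            >>> model.train(4, dataset, initial_epoch=2)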
        """
        _check_tft()
        device_target = context.get_context("device_target")
        # prepare dataset for obfuscated model
        train_dataset = self._prepare_obf_dataset(train_dataset)
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
            dataset_sink_mode = False

        Validator.check_bool(dataset_sink_mode)
        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")

        if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
            raise ValueError("When Model.build is used to initialize the model, the value of parameter 'epoch' in "
                             "Model.build should be equal to the value in Model.train, but got the value of epoch "
                             "in build {} and the value of epoch in train {} separately."
                             .format(train_dataset._warmup_epoch, epoch))

        # Parameter server and embedding cache mode check.
        if _is_ps_mode():
            if not dataset_sink_mode and _cache_enable():
                raise ValueError("Embedding cache mode should run with 'dataset_sink_mode=True'.")

        self._check_sink_mode_for_ds_debug_mode(dataset_sink_mode)

        Validator.check_is_int(sink_size)
        Validator.check_positive_int(epoch)
        Validator.check_non_negative_int(initial_epoch)
        if initial_epoch >= epoch:
            raise ValueError(f"For 'Model.train', the parameter 'epoch' must be bigger than parameter "
                             f"'initial_epoch', but got the parameter 'epoch' is {epoch}, 'initial_epoch' "
                             f"is {initial_epoch}.")

        dataset_size = train_dataset.get_dataset_size()
        if dataset_size == 0:
            raise ValueError("There is no valid data in dataset, please check the dataset file first.")
        if sink_size == -1:
            sink_size = dataset_size
        if sink_size < -1 or sink_size == 0:
            raise ValueError("For 'Model.train', the argument 'sink_size' must be -1 or positive, "
                             "but got {}.".format(sink_size))

        _device_number_check(self._parallel_mode, self._device_number)

        callbacks = _append_ccae(callbacks)
        if callbacks:
            self._check_methods_for_custom_callbacks(callbacks, "train")
        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode,
                    sink_size=sink_size,
                    initial_epoch=initial_epoch)

        # When it's distributed training and using MindRT,
        # the node id should be reset to start from 0.
        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case (or 'fit').
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()

    @staticmethod
    def _check_sink_mode_for_ds_debug_mode(dataset_sink_mode):
        if get_debug_mode() and dataset_sink_mode:
            raise ValueError("Dataset sink mode is not supported when dataset pipeline debug mode is on. "
                             "Please manually turn off sink mode.")

    @staticmethod
    def _check_methods_for_custom_callbacks(callbacks, current_mode):
        """
        Check whether the methods of customized callbacks are valid.

        Args:
            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object.
            current_mode (str): 'fit', 'train' or 'eval'.
        """
        old_version_methods_names = {'begin', 'end', 'epoch_begin', 'epoch_end', 'step_begin', 'step_end'}
        if not isinstance(callbacks, list):
            callbacks = [callbacks]
        for cb in callbacks:
            cb_name = cb.__class__.__name__
            if cb_name not in internal_cb_names:
                cb_methods_names = set(cb.__class__.__dict__.keys())
                invalid_methods_names = cb_methods_names & old_version_methods_names
                if invalid_methods_names:
                    if current_mode in ["train", "eval"]:
                        logger.warning("For %s callback, %s methods may not be supported in later versions. "
                                       "Use methods prefixed with 'on_train' or 'on_eval' instead "
                                       "when using customized callbacks." % (cb_name, invalid_methods_names))
                    else:
                        raise ValueError("For %s callback, %s methods may not be supported in later versions. "
                                         "Use methods prefixed with 'on_train' or 'on_eval' instead when "
                                         "using customized callbacks." % (cb_name, invalid_methods_names))
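
    # Illustration (hypothetical): a user callback that still defines the legacy
    # `epoch_end` hook would trigger the warning above under 'train'/'eval' and a
    # ValueError under 'fit':
    #
    #     class OldStyleMonitor(Callback):
    #         def epoch_end(self, run_context):
    #             ...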

    def fit(self, epoch, train_dataset, valid_dataset=None, valid_frequency=1, callbacks=None,
            dataset_sink_mode=False, valid_dataset_sink_mode=False, sink_size=-1, initial_epoch=0):
        """
        Fit API.

        Evaluation process will be performed during training process if `valid_dataset` is provided.

        For more details, please refer to :func:`mindspore.train.Model.train` and
        :func:`mindspore.train.Model.eval`.

        Args:
            epoch (int): Total training epochs. Generally, train network will be trained on complete dataset per epoch.
                         If `dataset_sink_mode` is set to True and `sink_size` is greater than 0, each epoch will
                         train `sink_size` steps instead of total steps of dataset.
                         If `epoch` is used with `initial_epoch`, it is to be understood as "final epoch".
            train_dataset (Dataset): A training dataset iterator. If `loss_fn` is defined, the data and label will be
                                     passed to the `network` and the `loss_fn` respectively, so a tuple (data, label)
                                     should be returned from dataset. If there are multiple data items or labels,
                                     set `loss_fn` to None and implement calculation of loss in `network`,
                                     then a tuple (data1, data2, data3, ...) with all data returned from dataset
                                     will be passed to the `network`.
            valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
                                     will be performed at the end of training process. Default: ``None`` .
            valid_frequency (int, list): Only relevant if `valid_dataset` is provided. If an integer, specifies
                                         how many training epochs to run before a new validation run is performed,
                                         e.g. `valid_frequency=2` runs validation every 2 epochs.
                                         If a list, specifies the epochs on which to run validation,
                                         e.g. `valid_frequency=[1, 5]` runs validation at the end of the 1st and 5th
                                         epochs. Default: ``1`` .
            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
                                                            which should be executed while training.
                                                            Default: ``None`` .
            dataset_sink_mode (bool): Determines whether to pass the train data through dataset channel.
                                      Configure pynative mode or CPU, the training process will be performed with
                                      dataset not sink. Default: ``False`` .
            valid_dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
                                            Default: ``False`` .
            sink_size (int): Control the number of steps for each sinking.
                             `sink_size` is invalid if `dataset_sink_mode` is False.
                             If sink_size = -1, sink the complete dataset for each epoch.
                             If sink_size > 0, sink sink_size data for each epoch.
                             Default: ``-1`` .
            initial_epoch (int): Epoch at which to start train, it is useful for resuming a previous training run.
                                 Default: ``0`` .

        Examples:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> train_dataset = create_dataset("train")
            >>> valid_dataset = create_dataset("test")
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"})
            >>> model.fit(2, train_dataset, valid_dataset)
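            >>> # A hedged sketch of list-style validation scheduling (epochs are
            >>> # illustrative): validate only after epochs 1, 3 and 6.
            >>> model.fit(6, train_dataset, valid_dataset, valid_frequency=[1, 3, 6])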

        Tutorial Examples:
            - `Advanced Encapsulation: Model - Train and Save Model
              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
        """
        device_target = context.get_context("device_target")
        if _is_ps_mode() and not _cache_enable() and (device_target in ["Ascend", "CPU"]) and dataset_sink_mode:
            logger.info("For PS mode, reset datasink mode to False when using Ascend or CPU backend.")
            dataset_sink_mode = False

        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        valid_dataset_sink_mode = Validator.check_bool(valid_dataset_sink_mode)

        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")

        if hasattr(train_dataset, '_warmup_epoch') and train_dataset._warmup_epoch != epoch:
            raise ValueError("When Model.build is used to initialize the model, the value of parameter `epoch` in "
                             "Model.build should be equal to the value in Model.fit, but got {} and {} separately."
                             .format(train_dataset._warmup_epoch, epoch))

        Validator.check_is_int(sink_size)
        Validator.check_positive_int(epoch)
        Validator.check_non_negative_int(initial_epoch)
        if initial_epoch >= epoch:
            raise ValueError(f"For 'Model.fit', the parameter 'epoch' must be bigger than parameter 'initial_epoch',"
                             f" but got the parameter 'epoch' is {epoch}, 'initial_epoch' is {initial_epoch}.")
        dataset_size = train_dataset.get_dataset_size()
        if dataset_size == 0:
            raise ValueError("There is no valid data in dataset, please check the dataset file first.")
        if sink_size == -1:
            sink_size = dataset_size
        if sink_size < -1 or sink_size == 0:
            raise ValueError("For 'Model.fit', the parameter 'sink_size' must be -1 or positive, "
                             "but got {}.".format(sink_size))

        _device_number_check(self._parallel_mode, self._device_number)

        if not isinstance(valid_frequency, (int, list)):
            raise TypeError(f"For 'Model.fit', the type of 'valid_frequency' must be a list or an integer, but got "
                            f"type {type(valid_frequency)}.")

        if valid_dataset and not self._metric_fns:
            raise ValueError("For 'Model.fit', if valid_dataset is not None, the model argument 'metrics' can not be "
                             "None or empty, you should set the argument 'metrics' for model.")
        if callbacks:
            self._check_methods_for_custom_callbacks(callbacks, "fit")
        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
                    dataset_sink_mode=dataset_sink_mode,
                    sink_size=sink_size,
                    initial_epoch=initial_epoch,
                    valid_dataset=valid_dataset,
                    valid_frequency=valid_frequency,
                    valid_dataset_sink_mode=valid_dataset_sink_mode)

    def build(self, train_dataset=None, valid_dataset=None, sink_size=-1, epoch=1):
        """
        Build computational graphs and data graphs with the sink mode.

        .. warning::
            This is an experimental API that is subject to change or deletion.

        Note:
            This interface builds the computational graphs; when it is executed first, 'Model.train' only
            performs the graphs execution. Pre-build process only supports `GRAPH_MODE` and `Ascend` target currently.
            It only supports dataset sink mode.

        Args:
            train_dataset (Dataset): A training dataset iterator. If `train_dataset` is defined, training graphs
                will be built. Default: ``None`` .
            valid_dataset (Dataset): An evaluating dataset iterator. If `valid_dataset` is defined, evaluation graphs
                will be built, and `metrics` in `Model` can not be None. Default: ``None`` .
            sink_size (int): Control the number of steps for each sinking. Default: ``-1`` .
            epoch (int): Control the training epochs. Default: ``1`` .

        Examples:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>> from mindspore.amp import FixedLossScaleManager
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> dataset = create_dataset()
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
            ...               loss_scale_manager=loss_scale_manager)
            >>> model.build(dataset, epoch=2)
            >>> model.train(2, dataset)
        """
        epoch = Validator.check_positive_int(epoch)
        if hasattr(self._train_network, '_is_check_and_refresh') and not self._train_network._is_check_and_refresh:
            self._train_network.check_names_and_refresh_name()
            self._train_network._is_check_and_refresh = True
        logger.info("Begin to init dataset in model.build() procedure.")
        self._init(train_dataset, valid_dataset, sink_size, epoch)
        logger.info("The model.build() call, which contains dataset warmup and network compile, succeeded.")

    def _eval_in_fit(self, valid_dataset, callbacks=None, dataset_sink_mode=True, cb_params=None):
        """
        Evaluation process in `mindspore.train.Model.fit`.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model. If `valid_dataset` is provided, evaluation process
                will be performed at the end of training process. Default: ``None``.
            callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, which should
                be executed while evaluation. Default: ``None``.
            dataset_sink_mode (bool): Determines whether to pass the validation data through dataset channel.
                Default: ``True``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.
        """
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0

        self._clear_metrics()

        if context.get_context("device_target") == "CPU" and dataset_sink_mode:
            dataset_sink_mode = False
            logger.info("CPU cannot support dataset sink mode currently. "
                        "So the evaluating process will be performed with dataset non-sink mode.")

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                return self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params, add_eval_loss=True)
            return self._eval_process(valid_dataset, list_callback, cb_params, add_eval_loss=True)

    def _eval_dataset_sink_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
        """
        Evaluation. The data would be passed to network through dataset channel.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: ``None``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.

        Returns:
            Dict, which returns the loss value and metrics values for the model in the test mode.
        """
        run_context = RunContext(cb_params)

        dataset_helper, eval_network = self._exec_preprocess(is_train=False,
                                                             dataset=valid_dataset,
                                                             dataset_sink_mode=True)
        cb_params.eval_network = eval_network
        cb_params.dataset_sink_mode = True
        list_callback.on_eval_begin(run_context)
        list_callback.on_eval_epoch_begin(run_context)
        for inputs in dataset_helper:
            cb_params.cur_step_num += 1
            inputs = _transfer_tensor_to_tuple(inputs)
            cb_params.eval_dataset_element = inputs
            list_callback.on_eval_step_begin(run_context)
            eval_network = self._check_network_mode(eval_network, False)
            outputs = eval_network(*inputs)
            cb_params.net_outputs = outputs
            list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)
            if add_eval_loss:
                eval_loss_fn = get_metric_fn("loss")
                eval_loss_fn.update(outputs[self._eval_indexes[0]])

        list_callback.on_eval_epoch_end(run_context)
        metrics = self._get_metrics()
        cb_params.metrics = metrics
        if add_eval_loss:
            eval_loss = eval_loss_fn.eval()
            cb_params.eval_results = copy.deepcopy(metrics)
            cb_params.eval_results.update({"eval_loss": eval_loss})
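            # E.g. (hypothetical values): with metrics={"accuracy"}, eval_results
            # could look like {"accuracy": 0.91, "eval_loss": 0.35}, while `metrics`
            # keeps only the user-defined entries.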
        list_callback.on_eval_end(run_context)

        return metrics

    def _eval_process(self, valid_dataset, list_callback=None, cb_params=None, add_eval_loss=False):
        """
        Evaluation. The data would be passed to network directly.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            list_callback (Callback): Executor of callback list. Default: ``None``.
            cb_params (_InternalCallbackParam): Callback parameters. Default: ``None``.

        Returns:
            Dict, which returns the loss value and metrics values for the model in the test mode.
        """
        run_context = RunContext(cb_params)
        cb_params.dataset_sink_mode = False
        list_callback.on_eval_begin(run_context)
        dataset_helper, _ = self._exec_preprocess(is_train=False,
                                                  dataset=valid_dataset,
                                                  dataset_sink_mode=False)
        list_callback.on_eval_epoch_begin(run_context)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
            next_element = _transfer_tensor_to_tuple(next_element)
            cb_params.eval_dataset_element = next_element
            list_callback.on_eval_step_begin(run_context)
            self._check_network_mode(self._eval_network, False)
            outputs = self._eval_network(*next_element)
            cb_params.net_outputs = outputs
            list_callback.on_eval_step_end(run_context)
            self._update_metrics(outputs)
            if add_eval_loss:
                eval_loss_fn = get_metric_fn("loss")
                eval_loss_fn.update(outputs[self._eval_indexes[0]])
            if run_context.get_stop_requested():
                break

        list_callback.on_eval_epoch_end(run_context)
        valid_dataset.reset()
        metrics = self._get_metrics()
        cb_params.metrics = metrics
        if add_eval_loss:
            eval_loss = eval_loss_fn.eval()
            cb_params.eval_results = copy.deepcopy(metrics)
            cb_params.eval_results.update({"eval_loss": eval_loss})
        list_callback.on_eval_end(run_context)
        return metrics

    def eval(self, valid_dataset, callbacks=None, dataset_sink_mode=False):
        """
        Evaluation API.

        Configured to pynative mode or CPU, the evaluating process will be performed with dataset non-sink mode.

        Note:
            If dataset_sink_mode is True, data will be sent to device. At this point, the dataset will be bound to this
            model, so the dataset cannot be used by other models. If the device is Ascend, features
            of data will be transferred one by one. The limitation of data transmission per time is 256M.

            The interface builds the computational graphs and then executes the computational graphs. However, when
            `Model.build` is executed first, it only performs the graphs execution.

        Args:
            valid_dataset (Dataset): Dataset to evaluate the model.
            callbacks (Optional[list(Callback), Callback]): List of callback objects or callback object,
                                                            which should be executed while evaluation.
                                                            Default: ``None`` .
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                                      Default: ``False`` .

        Returns:
            Dict, the key is the metric name defined by users and the value is the metrics value for
            the model in the test mode.

        Examples:
            >>> from mindspore import nn
            >>> from mindspore.train import Model
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> dataset = create_dataset()
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> acc = model.eval(dataset, dataset_sink_mode=False)

        Tutorial Examples:
            - `Advanced Encapsulation: Model - Train and Save Model
              <https://www.mindspore.cn/docs/en/master/model_train/train_process/model.html#training-and-saving-model>`_
        """
        valid_dataset = self._prepare_obf_dataset(valid_dataset)
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)

        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("For Model.eval, the model argument 'metrics' can not be None or empty, "
                             "you should set the argument 'metrics' for model.")
        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")
        if callbacks:
            self._check_methods_for_custom_callbacks(callbacks, "eval")
        cb_params = _InternalCallbackParam()
        cb_params.eval_network = self._eval_network
        cb_params.valid_dataset = valid_dataset
        cb_params.batch_num = valid_dataset.get_dataset_size()
        cb_params.mode = "eval"
        cb_params.cur_step_num = 0
        cb_params.list_callback = self._transform_callbacks(callbacks)
        if os.environ.get("ENABLE_FLOPS_UTILIZATION_COLLECTOR") == "1" and \
                FlopsUtilizationCollector not in cb_params.list_callback:
            cb_params.list_callback.insert(0, FlopsUtilizationCollector(
                cb_params.batch_num, full_flops=False))
        cb_params.network = self._network

        self._clear_metrics()

        # An embedding cache server acts as a storage service; there is no need to execute eval.
        is_embedding_cache_server = _is_role_pserver() and _cache_enable()
        if is_embedding_cache_server:
            metrics = self._get_metrics()
            cb_params.metrics = metrics
            return metrics

        if context.get_context("device_target") == "CPU" and dataset_sink_mode:
            dataset_sink_mode = False
            logger.info("CPU cannot support dataset sink mode currently. "
                        "So the evaluating process will be performed with dataset non-sink mode.")

        with _CallbackManager(callbacks) as list_callback:
            if dataset_sink_mode:
                eval_result = self._eval_dataset_sink_process(valid_dataset, list_callback, cb_params)
            else:
                eval_result = self._eval_process(valid_dataset, list_callback, cb_params)

        # When it's distributed training and using MindRT,
        # the node id should be reset to start from 0.
        # This is to avoid the timeout when finding the actor route tables in 'train' and 'eval' case (or 'fit').
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()

        return eval_result

    def _predict_lite(self, *predict_data, config=None):
        """
        Generate output predictions for the input samples using backend 'lite'.

        Args:
            predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
                The predict data, can be a single tensor,
                a list of tensors, or a tuple of tensors.

            config (dict, optional): The config parameter is enabled when the backend is 'lite'.
                The config includes two parts: config_path (configPath, str) and config_item (str, dict).
                When the config_item is set, its priority is higher than the config_path. Set the ranking
                table file for inference. The content of the configuration file is as follows:

                config_path defines the path of the configuration file, which is used to pass user-defined
                options during model building. In the following scenarios, users may need to set parameters.
                For example: "/home/user/config.ini". Default value: ``""`` . Here is the content of the
                config.ini file:

                .. code-block::

                    [ascend_context]
                    rank_table_file = [path_a](storage initial path of the rank table file)
                    [execution_plan]
                    [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
                    [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)

                When only the config_path is configured, it is done as follows:

                .. code-block::

                    config = {"configPath" : "/home/user/config.ini"}

                When only the config_dict is configured, it is done as follows:

                .. code-block::

                    config = {"ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}

                When both the `config_path` and the `config_dict` are configured, it is done as follows:

                .. code-block::

                    config = {"configPath" : "/home/user/config.ini",
                              "ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}

                Note that when "configPath" is configured in the config_dict together with config items,
                the path_b in the config_dict takes precedence.

        Returns:
            Tensor, array(s) of predictions.
        """
        def _get_lite_context(lite_context_input):
            # use default lite context parameters for now
            device_target = context.get_context("device_target").lower()
            lite_context_input.target = [device_target]
            if device_target == 'cpu':
                inter_op_parallel_num = context.get_context('inter_op_parallel_num')
                if inter_op_parallel_num and isinstance(inter_op_parallel_num, int):
                    lite_context_input.cpu.inter_op_parallel_num = inter_op_parallel_num
            elif device_target == 'gpu':
                device_id = context.get_context('device_id')
                if device_id and isinstance(device_id, int):
                    lite_context_input.gpu.device_id = device_id
                if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
                    from mindspore.communication import init, get_rank
                    init()
                    lite_context_input.gpu.rank_id = get_rank()
            elif device_target == 'ascend':
                device_id = context.get_context('device_id')
                if device_id and isinstance(device_id, int):
                    lite_context_input.ascend.device_id = device_id
                if context.get_auto_parallel_context("parallel_mode") == context.ParallelMode.SEMI_AUTO_PARALLEL:
                    from mindspore.communication import init, get_rank
                    init()
                    lite_context_input.ascend.rank_id = get_rank()
                    lite_context_input.ascend.provider = "ge"
            else:
                raise RuntimeError(f"For predict lite, device target should be in ['gpu', 'cpu', 'ascend']"
                                   f" but got {device_target}")
            return lite_context_input
+
|
|
        if not self._mindspore_lite:
            self._mindspore_lite = importlib.import_module('mindspore_lite')

        use_past = False  # by default, execute full model inference
        model_group_id = None
        if "is_first_iteration" in self._predict_network.get_flags():
            is_first_iteration = self._predict_network.get_flags()['is_first_iteration']
            if isinstance(is_first_iteration, bool):
                use_past = not is_first_iteration
                model_group_id = self._mindspore_lite_model_group_id

        check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))
        if use_past:
            # Execute incremental model inference
            if not self._lite_incremental_predictor:
                lite_context = _get_lite_context(self._mindspore_lite.Context())
                self._lite_incremental_predictor = \
                    self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
                                                              model_group_id=model_group_id, config=config)

            inputs = self._lite_incremental_predictor.get_inputs()
            if len(predict_data) != len(inputs):
                raise RuntimeError(f"For 'Model.predict', the number of predict_data {len(predict_data)} "
                                   f"is not equal to the number of net inputs {len(inputs)}")
            for i, single_data in enumerate(predict_data):
                inputs[i].set_data_from_numpy(single_data.asnumpy())
            outputs: list = self._lite_incremental_predictor.predict(inputs)
        else:
            # Execute full model inference
            if not self._lite_full_predictor:
                lite_context = _get_lite_context(self._mindspore_lite.Context())
                self._lite_full_predictor = \
                    self._mindspore_lite.lite_infer.LiteInfer(self, *predict_data, context=lite_context,
                                                              model_group_id=model_group_id, config=config)

            inputs = self._lite_full_predictor.get_inputs()
            if len(predict_data) != len(inputs):
                raise RuntimeError(f"For 'Model.predict', the number of predict_data {len(predict_data)} "
                                   f"is not equal to the number of net inputs {len(inputs)}")
            for i, single_data in enumerate(predict_data):
                inputs[i].set_data_from_numpy(single_data.asnumpy())
            outputs: list = self._lite_full_predictor.predict(inputs)
        if not outputs:
            return Tensor(outputs)
        if len(outputs) == 1:
            return Tensor(outputs[0].get_data_to_numpy())
        outputs = [Tensor(single_output.get_data_to_numpy()) for single_output in outputs]
        return tuple(outputs)

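The config precedence described in the docstring above can be exercised end to end. A minimal sketch, assuming a LeNet5 definition and placeholder paths (neither ships with this package):

.. code-block::

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.train import Model

    # LeNet5 and the ini path are placeholders, as in the docstring above.
    model = Model(LeNet5())
    input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
    # "configPath" supplies file-based options; "ascend_context" entries override them.
    config = {"configPath": "/home/user/config.ini",
              "ascend_context": {"rank_table_file": "path_b"}}
    result = model.predict(input_data, backend="lite", config=config)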
    def predict(self, *predict_data, backend=None, config=None):
        """
        Generate output predictions for the input samples.

        Args:
            predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
                The predict data, can be a single tensor,
                a list of tensor, or a tuple of tensor.
            backend (str): Select the predict backend. This parameter is an experimental feature,
                mainly used for MindSpore Lite cloud-side inference. Default: ``None`` .
            config (dict, optional): The config parameter is enabled when the backend is 'lite'.
                The config includes two parts: config_path ("configPath", str) and config_dict (str, dict).
                When the config_dict is set, its priority is higher than the config_path. They can be
                used, for example, to set the rank table file for inference.

                config_path defines the path of a configuration file, which is used to pass user-defined
                options during model building, for example: "/home/user/config.ini". Default value:
                ``""`` . Here is the content of the config.ini file:

                .. code-block::

                    [ascend_context]
                    rank_table_file = [path_a](storage initial path of the rank table file)
                    [execution_plan]
                    [op_name1] = data_type:float16 (operator named op_name1 is set to data type float16)
                    [op_name2] = data_type:float32 (operator named op_name2 is set to data type float32)

                When only the config_path is configured, it is done as follows:

                .. code-block::

                    config = {"configPath" : "/home/user/config.ini"}

                When only the config_dict is configured, it is done as follows:

                .. code-block::

                    config = {"ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name1" : "data_type:float16", "op_name2" : "data_type:float32"}}

                When both the `config_path` and the `config_dict` are configured, it is done as follows:

                .. code-block::

                    config = {"configPath" : "/home/user/config.ini",
                              "ascend_context" : {"rank_table_file" : "path_b"},
                              "execution_plan" : {"op_name3" : "data_type:float16", "op_name4" : "data_type:float32"}}

                Note that when "rank_table_file" is configured both in the file referenced by "configPath"
                and in the config_dict, the path_b from the config_dict takes precedence.

        Returns:
            Tensor, array(s) of predictions.

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Tensor
            >>> from mindspore.train import Model
            >>>
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> model = Model(LeNet5())
            >>> result = model.predict(input_data)
        """
        if backend not in ['lite', None]:
            raise ValueError(f"For Model.predict, `backend` should be 'lite' or None, but got {backend}")
        if backend == "lite" and self._lite_infer:
            # pylint: disable=broad-except
            try:
                return self._predict_lite(*predict_data, config=config)
            except RuntimeError:
                self._lite_infer = False
                logger.warning("Lite inference failed, fallback to original inference!")
            except ImportError:
                self._lite_infer = False
                logger.warning("Import mindspore_lite failed, fallback to original inference!")
            except BaseException as e:
                self._lite_infer = False
                logger.warning(f"Lite inference failed, {e.__str__()}, fallback to original inference!")

        def _check_input_data():
            """Input data check."""
            for item in predict_data:
                if item is None:
                    continue
                if isinstance(item, Tensor):
                    if item.size == 0:
                        msg = "The input data can not be empty."
                        logger.critical(msg)
                        raise ValueError(msg)
                    continue
                if not isinstance(item, (int, float, str)):
                    data_class_str = "Tensor, None, int, float, str"
                    raise TypeError(f'The types of input data must be in the Union({data_class_str}, '
                                    f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), '
                                    f'but got type {item if item is None else type(item).__name__}.')

        self._check_network_mode(self._predict_network, False)
        _check_input_data()
        _parallel_predict_check()
        result = self._predict_network(*predict_data)

        check_output_data(result)

        # When it's distributed training and using MindRT,
        # the node id should be reset to start from 0.
        # This is to avoid the timeout when finding the actor route tables in the 'train' and 'eval' (or 'fit') cases.
        if _enable_distributed_mindrt():
            _reset_op_id_with_offset()

        return result

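The rules enforced by `_check_input_data` above, as a short sketch (assumes a `model` built as in the Examples; the commented calls show what is rejected):

.. code-block::

    import numpy as np
    import mindspore
    from mindspore import Tensor

    # `model` is assumed to be a Model instance, e.g. Model(LeNet5()) as in the Examples.
    good = Tensor(np.ones([1, 1, 32, 32]), mindspore.float32)
    result = model.predict(good)            # non-empty Tensor: accepted
    # model.predict(Tensor(np.ones([0])))   # empty Tensor: raises ValueError
    # model.predict(object())               # unsupported type: raises TypeError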
    def _infer_train_check(self, train_dataset, dataset_sink_mode, sink_size):
        """
        Check arguments of training.

        Args:
            train_dataset (Dataset): A training dataset iterator.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
            sink_size (int): Control the amount of data in each sink.
        """
        if context.get_context("mode") != context.GRAPH_MODE:
            raise RuntimeError("Pre-compile process that generates parameter layout for the train network "
                               "only supports GRAPH MODE and Ascend target currently.")
        if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            raise RuntimeError("'infer_train_layout' only supports 'semi_auto_parallel' and 'auto_parallel' "
                               "mode, but got {}.".format(_get_parallel_mode()))
        dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
        if not dataset_sink_mode:
            raise ValueError("Only dataset sink mode is supported for now.")
        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode:
            raise ValueError("Dataset sink mode is currently not supported when training with a GraphCell.")
        Validator.check_is_int(sink_size)
        dataset_size = train_dataset.get_dataset_size()
        if dataset_size == 0:
            raise ValueError("There is no valid data in dataset, please check dataset file firstly.")
        if sink_size == -1:
            sink_size = dataset_size
        if sink_size < -1 or sink_size == 0:
            raise ValueError("For 'infer_train_layout', the argument 'sink_size' must be -1 or positive, "
                             "but got sink_size {}.".format(sink_size))

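For quick reference, the `sink_size` values accepted by the check above:

.. code-block::

    # sink_size == -1          -> sink the complete dataset each epoch
    # sink_size > 0            -> sink sink_size steps each epoch
    # sink_size == 0 or < -1   -> rejected with ValueError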
    def infer_train_layout(self, train_dataset, dataset_sink_mode=True, sink_size=-1):
        """
        Generate parameter layout for the train network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.
        Only dataset sink mode is supported for now.

        .. warning::
            This is an experimental API that is subject to change or deletion.

        Note:
            This is a pre-compile function. The arguments should be the same as model.train() function.

        Args:
            train_dataset (Dataset): A training dataset iterator. If there is no
                loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
                returned and passed to the network. Otherwise, a tuple (data, label) should
                be returned. The data and label would be passed to the network and loss
                function respectively.
            dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
                In PyNative mode or on a CPU target, the training process will be performed
                without dataset sinking. Default: ``True`` .
            sink_size (int): Control the number of steps for each sinking.
                If sink_size = -1, sink the complete dataset for each epoch.
                If sink_size > 0, sink sink_size data for each epoch.
                If dataset_sink_mode is False, sink_size is invalid.
                Default: ``-1`` .

        Returns:
            Dict, parameter layout dictionary used for loading distributed checkpoints.

        Examples:
            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
            >>> # mindspore.cn.
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, nn
            >>> from mindspore.train import Model
            >>> from mindspore.communication import init
            >>>
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> init()
            >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
            >>>
            >>> # Create the dataset taking MNIST as an example. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/mnist.py
            >>> dataset = create_dataset()
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
            >>> loss_scale_manager = ms.FixedLossScaleManager()
            >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None,
            ...               loss_scale_manager=loss_scale_manager)
            >>> layout_dict = model.infer_train_layout(dataset)
        """
        self._infer_train_check(train_dataset, dataset_sink_mode, sink_size)

        train_dataset.__no_send__ = True
        train_dataset_helper, train_network = self._exec_preprocess(is_train=True,
                                                                    dataset=train_dataset,
                                                                    dataset_sink_mode=dataset_sink_mode,
                                                                    sink_size=sink_size)
        for inputs in train_dataset_helper:
            train_network.compile(*inputs)
            break
        train_dataset.__model_hash__ = hash(self)
        return train_network.parameter_layout_dict

    def infer_predict_layout(self, *predict_data, skip_backend_compile=False):
        """
        Generate parameter layout for the predict network in 'AUTO_PARALLEL' or 'SEMI_AUTO_PARALLEL' mode.

        Data could be a single tensor or multiple tensors.

        Note:
            Batch data should be put together in one tensor.

        Args:
            predict_data (Union[Tensor, list[Tensor], tuple[Tensor]], optional):
                The predict data, can be a single tensor,
                a list of tensor, or a tuple of tensor.
            skip_backend_compile (bool): If True, only run the frontend compile process and
                skip the compile process on the device side. Setting this flag to True may
                cause the subsequent recompilation to miss the compile cache.

        Returns:
            Dict, parameter layout dictionary used for loading distributed checkpoints.
            It is always used as an input parameter of load_distributed_checkpoint.

        Raises:
            RuntimeError: If not in GRAPH_MODE.

        Examples:
            >>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on
            >>> # mindspore.cn.
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor
            >>> from mindspore.train import Model
            >>> from mindspore.communication import init
            >>>
            >>> ms.set_context(mode=ms.GRAPH_MODE)
            >>> init()
            >>> ms.set_auto_parallel_context(full_batch=True, parallel_mode=ms.ParallelMode.SEMI_AUTO_PARALLEL)
            >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
            >>> model = Model(Net())
            >>> predict_map = model.infer_predict_layout(input_data)
        """
        if context.get_context("mode") != context.GRAPH_MODE:
            raise RuntimeError("Pre-compile process that generates parameter layout for the predict network "
                               "only supports GRAPH MODE and Ascend target currently.")
        if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            raise RuntimeError('Infer predict layout only supports semi auto parallel and auto parallel mode.')
        _parallel_predict_check()
        check_input_data(*predict_data, data_class=(int, float, str, None, Tensor))

        predict_net = self._predict_network
        # Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
        predict_net = self._check_network_mode(predict_net, False)
        if skip_backend_compile:
            origin_phase = predict_net.phase
            predict_net.phase = "export." + predict_net.phase
            predict_net.compile(*predict_data)
            # set phase back to prevent from hitting incomplete compile cache
            predict_net.phase = origin_phase
        else:
            predict_net.compile(*predict_data)
        return predict_net.parameter_layout_dict

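The returned layout dictionary is meant to be fed to checkpoint loading. A minimal sketch, assuming the `model` and `input_data` from the Examples above and placeholder checkpoint names; `mindspore.load_distributed_checkpoint` is the public API that consumes such a layout:

.. code-block::

    import mindspore as ms

    # `model` and `input_data` as in the Examples above; checkpoint names are placeholders.
    predict_map = model.infer_predict_layout(input_data)
    ckpt_files = ["rank_0.ckpt", "rank_1.ckpt"]
    ms.load_distributed_checkpoint(model.predict_network, ckpt_files, predict_map)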
    def _flush_from_cache(self, cb_params):
        """Flush cache data to host if the tensor is cache-enabled."""
        params = cb_params.train_network.get_parameters()
        for param in params:
            if param.cache_enable:
                Tensor(param).flush_from_cache()

    @property
    def train_network(self):
        """
        Get the model's train network.

        Returns:
            Object, the instance of train network.
        """
        return self._train_network

    @property
    def predict_network(self):
        """
        Get the model's predict network.

        Returns:
            Object, the instance of predict network.
        """
        return self._predict_network

    @property
    def eval_network(self):
        """
        Get the model's eval network.

        Returns:
            Object, the instance of evaluate network.
        """
        return self._eval_network

    def _prepare_obf_dataset(self, dataset):
        if not hasattr(self._network, 'obf_ratios'):
            return dataset
        data_size = dataset.get_dataset_size()
        obf_ratio_dataset = []
        for _ in range(data_size):
            obf_ratio_dataset.append(self._network.obf_ratios)
        obf_ratio_dataset = ds.NumpySlicesDataset(data=obf_ratio_dataset, column_names=["y_obf"])
        dataset = ds.zip((dataset, obf_ratio_dataset))
        return dataset

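The zip pattern used by `_prepare_obf_dataset` can be reproduced standalone. A small sketch with made-up columns:

.. code-block::

    import numpy as np
    import mindspore.dataset as ds

    base = ds.NumpySlicesDataset(data=np.arange(4), column_names=["x"])
    extra = ds.NumpySlicesDataset(data=np.ones(4), column_names=["y_obf"])
    zipped = ds.zip((base, extra))   # rows now carry both "x" and "y_obf" columns
    for row in zipped.create_dict_iterator(output_numpy=True):
        print(row["x"], row["y_obf"])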

__all__ = ["Model"]