mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -0
- mindspore/__init__.py +53 -0
- mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
- mindspore/_c_expression.cpython-311-darwin.so +0 -0
- mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
- mindspore/_check_jit_forbidden_api.py +106 -0
- mindspore/_checkparam.py +1419 -0
- mindspore/_extends/__init__.py +23 -0
- mindspore/_extends/builtin_operations.py +224 -0
- mindspore/_extends/graph_kernel/__init__.py +17 -0
- mindspore/_extends/graph_kernel/model/__init__.py +19 -0
- mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
- mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
- mindspore/_extends/graph_kernel/model/model.py +553 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
- mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
- mindspore/_extends/graph_kernel/splitter.py +140 -0
- mindspore/_extends/graph_kernel/utils.py +28 -0
- mindspore/_extends/parallel_compile/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
- mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
- mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
- mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
- mindspore/_extends/parse/__init__.py +49 -0
- mindspore/_extends/parse/compile_config.py +299 -0
- mindspore/_extends/parse/namespace.py +136 -0
- mindspore/_extends/parse/parser.py +1448 -0
- mindspore/_extends/parse/resources.py +213 -0
- mindspore/_extends/parse/standard_method.py +4475 -0
- mindspore/_extends/parse/trope.py +97 -0
- mindspore/_extends/pijit/__init__.py +23 -0
- mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
- mindspore/_extends/remote/__init__.py +19 -0
- mindspore/_extends/remote/kernel_build_server.py +199 -0
- mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/_extends/utils.py +68 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_profiler.py +30 -0
- mindspore/amp.py +433 -0
- mindspore/boost/__init__.py +42 -0
- mindspore/boost/adasum.py +319 -0
- mindspore/boost/base.py +535 -0
- mindspore/boost/boost.py +400 -0
- mindspore/boost/boost_cell_wrapper.py +790 -0
- mindspore/boost/dim_reduce.py +323 -0
- mindspore/boost/grad_accumulation.py +79 -0
- mindspore/boost/grad_freeze.py +382 -0
- mindspore/boost/group_loss_scale_manager.py +166 -0
- mindspore/boost/less_batch_normalization.py +174 -0
- mindspore/common/__init__.py +86 -0
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_decorator.py +50 -0
- mindspore/common/_jit_fallback_utils.py +110 -0
- mindspore/common/_monad.py +25 -0
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_adapter.py +74 -0
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_register_for_tensor.py +46 -0
- mindspore/common/_stub_tensor.py +210 -0
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/_utils.py +122 -0
- mindspore/common/api.py +2064 -0
- mindspore/common/auto_dynamic_shape.py +507 -0
- mindspore/common/dtype.py +422 -0
- mindspore/common/dump.py +130 -0
- mindspore/common/file_system.py +48 -0
- mindspore/common/generator.py +254 -0
- mindspore/common/hook_handle.py +143 -0
- mindspore/common/initializer.py +880 -0
- mindspore/common/jit_config.py +98 -0
- mindspore/common/lazy_inline.py +240 -0
- mindspore/common/mindir_util.py +111 -0
- mindspore/common/mutable.py +234 -0
- mindspore/common/no_inline.py +54 -0
- mindspore/common/np_dtype.py +25 -0
- mindspore/common/parameter.py +1081 -0
- mindspore/common/recompute.py +292 -0
- mindspore/common/seed.py +260 -0
- mindspore/common/sparse_tensor.py +1175 -0
- mindspore/common/symbol.py +122 -0
- mindspore/common/tensor.py +5039 -0
- mindspore/communication/__init__.py +37 -0
- mindspore/communication/_comm_helper.py +501 -0
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/communication/comm_func.py +1395 -0
- mindspore/communication/management.py +673 -0
- mindspore/config/op_info.config +533 -0
- mindspore/context.py +2077 -0
- mindspore/dataset/__init__.py +90 -0
- mindspore/dataset/audio/__init__.py +61 -0
- mindspore/dataset/audio/transforms.py +3690 -0
- mindspore/dataset/audio/utils.py +386 -0
- mindspore/dataset/audio/validators.py +1172 -0
- mindspore/dataset/callback/__init__.py +20 -0
- mindspore/dataset/callback/ds_callback.py +368 -0
- mindspore/dataset/callback/validators.py +32 -0
- mindspore/dataset/core/__init__.py +13 -0
- mindspore/dataset/core/config.py +1095 -0
- mindspore/dataset/core/datatypes.py +101 -0
- mindspore/dataset/core/py_util_helpers.py +65 -0
- mindspore/dataset/core/validator_helpers.py +781 -0
- mindspore/dataset/debug/__init__.py +21 -0
- mindspore/dataset/debug/debug_hook.py +97 -0
- mindspore/dataset/debug/pre_defined_hook.py +67 -0
- mindspore/dataset/engine/__init__.py +124 -0
- mindspore/dataset/engine/cache_admin.py +47 -0
- mindspore/dataset/engine/cache_client.py +129 -0
- mindspore/dataset/engine/datasets.py +4582 -0
- mindspore/dataset/engine/datasets_audio.py +911 -0
- mindspore/dataset/engine/datasets_standard_format.py +543 -0
- mindspore/dataset/engine/datasets_text.py +2161 -0
- mindspore/dataset/engine/datasets_user_defined.py +1184 -0
- mindspore/dataset/engine/datasets_vision.py +4816 -0
- mindspore/dataset/engine/iterators.py +371 -0
- mindspore/dataset/engine/obs/__init__.py +23 -0
- mindspore/dataset/engine/obs/config_loader.py +68 -0
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
- mindspore/dataset/engine/obs/util.py +482 -0
- mindspore/dataset/engine/offload.py +596 -0
- mindspore/dataset/engine/queue.py +304 -0
- mindspore/dataset/engine/samplers.py +895 -0
- mindspore/dataset/engine/serializer_deserializer.py +159 -0
- mindspore/dataset/engine/validators.py +2895 -0
- mindspore/dataset/text/__init__.py +51 -0
- mindspore/dataset/text/transforms.py +1703 -0
- mindspore/dataset/text/utils.py +715 -0
- mindspore/dataset/text/validators.py +642 -0
- mindspore/dataset/transforms/__init__.py +45 -0
- mindspore/dataset/transforms/c_transforms.py +638 -0
- mindspore/dataset/transforms/py_transforms.py +393 -0
- mindspore/dataset/transforms/py_transforms_util.py +255 -0
- mindspore/dataset/transforms/transforms.py +1260 -0
- mindspore/dataset/transforms/validators.py +410 -0
- mindspore/dataset/utils/__init__.py +19 -0
- mindspore/dataset/utils/browse_dataset.py +190 -0
- mindspore/dataset/utils/line_reader.py +126 -0
- mindspore/dataset/vision/__init__.py +65 -0
- mindspore/dataset/vision/c_transforms.py +2641 -0
- mindspore/dataset/vision/py_transforms.py +2120 -0
- mindspore/dataset/vision/py_transforms_util.py +1660 -0
- mindspore/dataset/vision/transforms.py +7295 -0
- mindspore/dataset/vision/utils.py +863 -0
- mindspore/dataset/vision/validators.py +1483 -0
- mindspore/default_config.py +2 -0
- mindspore/experimental/__init__.py +20 -0
- mindspore/experimental/es/__init__.py +22 -0
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/experimental/es/embedding_service_layer.py +581 -0
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/experimental/llm_boost/atb/__init__.py +23 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/map_parameter.py +309 -0
- mindspore/experimental/optim/__init__.py +40 -0
- mindspore/experimental/optim/adadelta.py +161 -0
- mindspore/experimental/optim/adagrad.py +168 -0
- mindspore/experimental/optim/adam.py +193 -0
- mindspore/experimental/optim/adamax.py +170 -0
- mindspore/experimental/optim/adamw.py +290 -0
- mindspore/experimental/optim/asgd.py +153 -0
- mindspore/experimental/optim/lr_scheduler.py +1371 -0
- mindspore/experimental/optim/nadam.py +157 -0
- mindspore/experimental/optim/optimizer.py +262 -0
- mindspore/experimental/optim/radam.py +194 -0
- mindspore/experimental/optim/rmsprop.py +154 -0
- mindspore/experimental/optim/rprop.py +164 -0
- mindspore/experimental/optim/sgd.py +156 -0
- mindspore/hal/__init__.py +40 -0
- mindspore/hal/_ascend.py +57 -0
- mindspore/hal/_base.py +57 -0
- mindspore/hal/_cpu.py +56 -0
- mindspore/hal/_gpu.py +57 -0
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/device.py +356 -0
- mindspore/hal/event.py +179 -0
- mindspore/hal/memory.py +326 -0
- mindspore/hal/stream.py +357 -0
- mindspore/include/OWNERS +7 -0
- mindspore/include/api/allocator.h +97 -0
- mindspore/include/api/callback/callback.h +93 -0
- mindspore/include/api/callback/ckpt_saver.h +41 -0
- mindspore/include/api/callback/loss_monitor.h +33 -0
- mindspore/include/api/callback/lr_scheduler.h +51 -0
- mindspore/include/api/callback/time_monitor.h +34 -0
- mindspore/include/api/callback/train_accuracy.h +37 -0
- mindspore/include/api/cell.h +90 -0
- mindspore/include/api/cfg.h +82 -0
- mindspore/include/api/context.h +602 -0
- mindspore/include/api/data_type.h +47 -0
- mindspore/include/api/delegate.h +178 -0
- mindspore/include/api/delegate_api.h +75 -0
- mindspore/include/api/dual_abi_helper.h +208 -0
- mindspore/include/api/format.h +28 -0
- mindspore/include/api/graph.h +46 -0
- mindspore/include/api/kernel.h +58 -0
- mindspore/include/api/kernel_api.h +168 -0
- mindspore/include/api/metrics/accuracy.h +36 -0
- mindspore/include/api/metrics/metrics.h +41 -0
- mindspore/include/api/model.h +438 -0
- mindspore/include/api/model_group.h +91 -0
- mindspore/include/api/model_parallel_runner.h +168 -0
- mindspore/include/api/serialization.h +185 -0
- mindspore/include/api/status.h +192 -0
- mindspore/include/api/types.h +431 -0
- mindspore/include/api/visible.h +41 -0
- mindspore/include/c_api/context_c.h +179 -0
- mindspore/include/c_api/data_type_c.h +52 -0
- mindspore/include/c_api/format_c.h +46 -0
- mindspore/include/c_api/model_c.h +347 -0
- mindspore/include/c_api/status_c.h +79 -0
- mindspore/include/c_api/tensor_c.h +146 -0
- mindspore/include/c_api/types_c.h +67 -0
- mindspore/include/dataset/config.h +163 -0
- mindspore/include/dataset/constants.h +363 -0
- mindspore/include/dataset/execute.h +196 -0
- mindspore/include/dataset/text.h +1092 -0
- mindspore/include/dataset/transforms.h +638 -0
- mindspore/include/dataset/vision.h +2129 -0
- mindspore/include/dataset/vision_ascend.h +206 -0
- mindspore/include/dataset/vision_lite.h +625 -0
- mindspore/lib/libavcodec.59.dylib +0 -0
- mindspore/lib/libavdevice.59.dylib +0 -0
- mindspore/lib/libavfilter.8.dylib +0 -0
- mindspore/lib/libavformat.59.dylib +0 -0
- mindspore/lib/libavutil.57.dylib +0 -0
- mindspore/lib/libdnnl.2.dylib +0 -0
- mindspore/lib/libicudata.69.dylib +0 -0
- mindspore/lib/libicui18n.69.dylib +0 -0
- mindspore/lib/libicuuc.69.dylib +0 -0
- mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
- mindspore/lib/libmindspore_backend.dylib +0 -0
- mindspore/lib/libmindspore_common.dylib +0 -0
- mindspore/lib/libmindspore_core.dylib +0 -0
- mindspore/lib/libmindspore_glog.0.dylib +0 -0
- mindspore/lib/libmindspore_gpr.15.dylib +0 -0
- mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
- mindspore/lib/libmindspore_grpc.15.dylib +0 -0
- mindspore/lib/libmindspore_np_dtype.dylib +0 -0
- mindspore/lib/libmindspore_ops.dylib +0 -0
- mindspore/lib/libmindspore_upb.15.dylib +0 -0
- mindspore/lib/libnnacl.dylib +0 -0
- mindspore/lib/libopencv_core.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
- mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
- mindspore/lib/libps_cache.dylib +0 -0
- mindspore/lib/libswresample.4.dylib +0 -0
- mindspore/lib/libswscale.6.dylib +0 -0
- mindspore/lib/libtinyxml2.8.dylib +0 -0
- mindspore/log.py +633 -0
- mindspore/mindrecord/__init__.py +43 -0
- mindspore/mindrecord/common/__init__.py +17 -0
- mindspore/mindrecord/common/constant.py +20 -0
- mindspore/mindrecord/common/enums.py +44 -0
- mindspore/mindrecord/common/exceptions.py +311 -0
- mindspore/mindrecord/config.py +809 -0
- mindspore/mindrecord/filereader.py +174 -0
- mindspore/mindrecord/filewriter.py +722 -0
- mindspore/mindrecord/mindpage.py +210 -0
- mindspore/mindrecord/shardheader.py +141 -0
- mindspore/mindrecord/shardindexgenerator.py +74 -0
- mindspore/mindrecord/shardreader.py +117 -0
- mindspore/mindrecord/shardsegment.py +128 -0
- mindspore/mindrecord/shardutils.py +185 -0
- mindspore/mindrecord/shardwriter.py +237 -0
- mindspore/mindrecord/tools/__init__.py +17 -0
- mindspore/mindrecord/tools/cifar10.py +140 -0
- mindspore/mindrecord/tools/cifar100.py +153 -0
- mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
- mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
- mindspore/mindrecord/tools/csv_to_mr.py +200 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
- mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
- mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
- mindspore/mint/__init__.py +1586 -0
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/linalg/__init__.py +22 -0
- mindspore/mint/nn/__init__.py +757 -0
- mindspore/mint/nn/functional.py +679 -0
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +206 -0
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +73 -0
- mindspore/nn/__init__.py +47 -0
- mindspore/nn/cell.py +2787 -0
- mindspore/nn/dynamic_lr.py +482 -0
- mindspore/nn/grad/__init__.py +21 -0
- mindspore/nn/grad/cell_grad.py +196 -0
- mindspore/nn/layer/__init__.py +63 -0
- mindspore/nn/layer/activation.py +1822 -0
- mindspore/nn/layer/basic.py +1629 -0
- mindspore/nn/layer/channel_shuffle.py +90 -0
- mindspore/nn/layer/combined.py +248 -0
- mindspore/nn/layer/container.py +734 -0
- mindspore/nn/layer/conv.py +1505 -0
- mindspore/nn/layer/dense.py +204 -0
- mindspore/nn/layer/embedding.py +869 -0
- mindspore/nn/layer/image.py +661 -0
- mindspore/nn/layer/math.py +1069 -0
- mindspore/nn/layer/normalization.py +1273 -0
- mindspore/nn/layer/padding.py +880 -0
- mindspore/nn/layer/pooling.py +2302 -0
- mindspore/nn/layer/rnn_cells.py +388 -0
- mindspore/nn/layer/rnns.py +849 -0
- mindspore/nn/layer/thor_layer.py +963 -0
- mindspore/nn/layer/timedistributed.py +155 -0
- mindspore/nn/layer/transformer.py +823 -0
- mindspore/nn/learning_rate_schedule.py +512 -0
- mindspore/nn/loss/__init__.py +36 -0
- mindspore/nn/loss/loss.py +2924 -0
- mindspore/nn/metrics.py +53 -0
- mindspore/nn/optim/__init__.py +45 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
- mindspore/nn/optim/ada_grad.py +217 -0
- mindspore/nn/optim/adadelta.py +206 -0
- mindspore/nn/optim/adafactor.py +448 -0
- mindspore/nn/optim/adam.py +1297 -0
- mindspore/nn/optim/adamax.py +220 -0
- mindspore/nn/optim/adasum.py +548 -0
- mindspore/nn/optim/asgd.py +216 -0
- mindspore/nn/optim/ftrl.py +401 -0
- mindspore/nn/optim/lamb.py +296 -0
- mindspore/nn/optim/lars.py +202 -0
- mindspore/nn/optim/lazyadam.py +533 -0
- mindspore/nn/optim/momentum.py +239 -0
- mindspore/nn/optim/optimizer.py +1034 -0
- mindspore/nn/optim/proximal_ada_grad.py +242 -0
- mindspore/nn/optim/rmsprop.py +264 -0
- mindspore/nn/optim/rprop.py +251 -0
- mindspore/nn/optim/sgd.py +237 -0
- mindspore/nn/optim/tft_wrapper.py +127 -0
- mindspore/nn/optim/thor.py +1310 -0
- mindspore/nn/probability/__init__.py +22 -0
- mindspore/nn/probability/bijector/__init__.py +35 -0
- mindspore/nn/probability/bijector/bijector.py +337 -0
- mindspore/nn/probability/bijector/exp.py +65 -0
- mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
- mindspore/nn/probability/bijector/invert.py +126 -0
- mindspore/nn/probability/bijector/power_transform.py +196 -0
- mindspore/nn/probability/bijector/scalar_affine.py +167 -0
- mindspore/nn/probability/bijector/softplus.py +189 -0
- mindspore/nn/probability/bnn_layers/__init__.py +29 -0
- mindspore/nn/probability/bnn_layers/_util.py +46 -0
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
- mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
- mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
- mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
- mindspore/nn/probability/distribution/__init__.py +56 -0
- mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
- mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
- mindspore/nn/probability/distribution/_utils/utils.py +362 -0
- mindspore/nn/probability/distribution/bernoulli.py +334 -0
- mindspore/nn/probability/distribution/beta.py +391 -0
- mindspore/nn/probability/distribution/categorical.py +435 -0
- mindspore/nn/probability/distribution/cauchy.py +383 -0
- mindspore/nn/probability/distribution/distribution.py +827 -0
- mindspore/nn/probability/distribution/exponential.py +350 -0
- mindspore/nn/probability/distribution/gamma.py +391 -0
- mindspore/nn/probability/distribution/geometric.py +335 -0
- mindspore/nn/probability/distribution/gumbel.py +257 -0
- mindspore/nn/probability/distribution/half_normal.py +133 -0
- mindspore/nn/probability/distribution/laplace.py +128 -0
- mindspore/nn/probability/distribution/log_normal.py +272 -0
- mindspore/nn/probability/distribution/logistic.py +379 -0
- mindspore/nn/probability/distribution/normal.py +336 -0
- mindspore/nn/probability/distribution/poisson.py +288 -0
- mindspore/nn/probability/distribution/student_t.py +149 -0
- mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
- mindspore/nn/probability/distribution/uniform.py +375 -0
- mindspore/nn/reinforcement/__init__.py +24 -0
- mindspore/nn/reinforcement/_batch_read_write.py +142 -0
- mindspore/nn/reinforcement/_tensors_queue.py +152 -0
- mindspore/nn/reinforcement/tensor_array.py +145 -0
- mindspore/nn/sparse/__init__.py +23 -0
- mindspore/nn/sparse/sparse.py +147 -0
- mindspore/nn/wrap/__init__.py +49 -0
- mindspore/nn/wrap/cell_wrapper.py +968 -0
- mindspore/nn/wrap/grad_reducer.py +608 -0
- mindspore/nn/wrap/loss_scale.py +694 -0
- mindspore/numpy/__init__.py +121 -0
- mindspore/numpy/array_creations.py +2731 -0
- mindspore/numpy/array_ops.py +2629 -0
- mindspore/numpy/dtypes.py +185 -0
- mindspore/numpy/fft.py +966 -0
- mindspore/numpy/logic_ops.py +936 -0
- mindspore/numpy/math_ops.py +5911 -0
- mindspore/numpy/utils.py +214 -0
- mindspore/numpy/utils_const.py +565 -0
- mindspore/ops/__init__.py +56 -0
- mindspore/ops/_constants.py +30 -0
- mindspore/ops/_grad_experimental/__init__.py +31 -0
- mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
- mindspore/ops/_grad_experimental/grad_base.py +143 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
- mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
- mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
- mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
- mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
- mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
- mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
- mindspore/ops/_op_impl/__init__.py +23 -0
- mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
- mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
- mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
- mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
- mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
- mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
- mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
- mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
- mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
- mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
- mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
- mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
- mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
- mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
- mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
- mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
- mindspore/ops/_op_impl/aicpu/abs.py +36 -0
- mindspore/ops/_op_impl/aicpu/acos.py +32 -0
- mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
- mindspore/ops/_op_impl/aicpu/add.py +43 -0
- mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
- mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
- mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
- mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
- mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
- mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
- mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
- mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/angle.py +31 -0
- mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
- mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
- mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
- mindspore/ops/_op_impl/aicpu/asin.py +32 -0
- mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
- mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
- mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
- mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
- mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
- mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
- mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
- mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
- mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
- mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
- mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
- mindspore/ops/_op_impl/aicpu/cast.py +225 -0
- mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
- mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
- mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
- mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
- mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
- mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
- mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
- mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
- mindspore/ops/_op_impl/aicpu/complex.py +32 -0
- mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
- mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
- mindspore/ops/_op_impl/aicpu/concat.py +57 -0
- mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
- mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
- mindspore/ops/_op_impl/aicpu/conj.py +42 -0
- mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/cos.py +34 -0
- mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
- mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
- mindspore/ops/_op_impl/aicpu/cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
- mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
- mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
- mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
- mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
- mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
- mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
- mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
- mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
- mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
- mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
- mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
- mindspore/ops/_op_impl/aicpu/diag.py +36 -0
- mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
- mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
- mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
- mindspore/ops/_op_impl/aicpu/div.py +41 -0
- mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
- mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
- mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
- mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
- mindspore/ops/_op_impl/aicpu/eig.py +35 -0
- mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
- mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
- mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
- mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
- mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/exp.py +37 -0
- mindspore/ops/_op_impl/aicpu/expand.py +45 -0
- mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
- mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
- mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
- mindspore/ops/_op_impl/aicpu/eye.py +44 -0
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
- mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
- mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
- mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
- mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
- mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
- mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
- mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
- mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/gather.py +46 -0
- mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
- mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
- mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
- mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
- mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
- mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
- mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
- mindspore/ops/_op_impl/aicpu/glu.py +33 -0
- mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/greater.py +41 -0
- mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
- mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
- mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
- mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
- mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
- mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
- mindspore/ops/_op_impl/aicpu/identity.py +42 -0
- mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
- mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
- mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
- mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
- mindspore/ops/_op_impl/aicpu/imag.py +31 -0
- mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
- mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
- mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
- mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
- mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
- mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
- mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
- mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
- mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
- mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
- mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
- mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/less.py +41 -0
- mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
- mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
- mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
- mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
- mindspore/ops/_op_impl/aicpu/log.py +37 -0
- mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
- mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
- mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
- mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
- mindspore/ops/_op_impl/aicpu/logit.py +33 -0
- mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
- mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
- mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
- mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
- mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
- mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
- mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
- mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
- mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
- mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
- mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
- mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
- mindspore/ops/_op_impl/aicpu/median.py +39 -0
- mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
- mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
- mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
- mindspore/ops/_op_impl/aicpu/mul.py +43 -0
- mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
- mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
- mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
- mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
- mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
- mindspore/ops/_op_impl/aicpu/neg.py +36 -0
- mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
- mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
- mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
- mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
- mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
- mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
- mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
- mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
- mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
- mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
- mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
- mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
- mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
- mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
- mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
- mindspore/ops/_op_impl/aicpu/padding.py +41 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
- mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
- mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
- mindspore/ops/_op_impl/aicpu/polar.py +32 -0
- mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
- mindspore/ops/_op_impl/aicpu/pow.py +39 -0
- mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
- mindspore/ops/_op_impl/aicpu/qr.py +36 -0
- mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
- mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
- mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
- mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
- mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
- mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
- mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
- mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
- mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
- mindspore/ops/_op_impl/aicpu/range.py +36 -0
- mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/real.py +31 -0
- mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
- mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
- mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
- mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
- mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
- mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
- mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
- mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
- mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
- mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
- mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
- mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
- mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
- mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/round.py +34 -0
- mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
- mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
- mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
- mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
- mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
- mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
- mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
- mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
- mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
- mindspore/ops/_op_impl/aicpu/select.py +45 -0
- mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
- mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
- mindspore/ops/_op_impl/aicpu/sign.py +36 -0
- mindspore/ops/_op_impl/aicpu/sin.py +34 -0
- mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
- mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
- mindspore/ops/_op_impl/aicpu/slice.py +59 -0
- mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
- mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
- mindspore/ops/_op_impl/aicpu/sort.py +39 -0
- mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
- mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
- mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
- mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
- mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
- mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
- mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
- mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
- mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
- mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
- mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
- mindspore/ops/_op_impl/aicpu/split.py +45 -0
- mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
- mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/square.py +35 -0
- mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
- mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
- mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
- mindspore/ops/_op_impl/aicpu/stack.py +45 -0
- mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
- mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
- mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
- mindspore/ops/_op_impl/aicpu/stft.py +70 -0
- mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
- mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/sub.py +41 -0
- mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
- mindspore/ops/_op_impl/aicpu/tan.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
- mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
- mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
- mindspore/ops/_op_impl/aicpu/tile.py +56 -0
- mindspore/ops/_op_impl/aicpu/topk.py +34 -0
- mindspore/ops/_op_impl/aicpu/trace.py +40 -0
- mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
- mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
- mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
- mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
- mindspore/ops/_op_impl/aicpu/tril.py +42 -0
- mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
- mindspore/ops/_op_impl/aicpu/triu.py +43 -0
- mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
- mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
- mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
- mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
- mindspore/ops/_op_impl/aicpu/unique.py +31 -0
- mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
- mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
- mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
- mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
- mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
- mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
- mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
- mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
- mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
- mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
- mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
- mindspore/ops/_op_impl/akg/__init__.py +19 -0
- mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
- mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
- mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
- mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
- mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
- mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
- mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
- mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
- mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
- mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
- mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
- mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
- mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
- mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
- mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
- mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
- mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
- mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
- mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
- mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
- mindspore/ops/_op_impl/cpu/__init__.py +78 -0
- mindspore/ops/_op_impl/cpu/adam.py +49 -0
- mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
- mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
- mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
- mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
- mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
- mindspore/ops/_op_impl/cpu/cast.py +171 -0
- mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
- mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
- mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
- mindspore/ops/_op_impl/cpu/div.py +32 -0
- mindspore/ops/_op_impl/cpu/dropout.py +31 -0
- mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
- mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
- mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
- mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
- mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
- mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
- mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
- mindspore/ops/_op_impl/cpu/hswish.py +32 -0
- mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
- mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
- mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
- mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
- mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
- mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
- mindspore/ops/_op_impl/cpu/maximum.py +35 -0
- mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
- mindspore/ops/_op_impl/cpu/minimum.py +40 -0
- mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
- mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
- mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
- mindspore/ops/_op_impl/cpu/mul.py +32 -0
- mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
- mindspore/ops/_op_impl/cpu/pad.py +32 -0
- mindspore/ops/_op_impl/cpu/pow.py +32 -0
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
- mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
- mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
- mindspore/ops/_op_impl/cpu/range.py +34 -0
- mindspore/ops/_op_impl/cpu/real_div.py +33 -0
- mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
- mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
- mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
- mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
- mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
- mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
- mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
- mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
- mindspore/ops/_op_impl/cpu/split.py +34 -0
- mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
- mindspore/ops/_op_impl/cpu/stack.py +38 -0
- mindspore/ops/_op_impl/cpu/sub.py +32 -0
- mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
- mindspore/ops/_op_impl/cpu/tile.py +37 -0
- mindspore/ops/_op_impl/cpu/top_k.py +31 -0
- mindspore/ops/_op_impl/cpu/transpose.py +39 -0
- mindspore/ops/_primitive_cache.py +90 -0
- mindspore/ops/_register_for_op.py +73 -0
- mindspore/ops/_utils/__init__.py +20 -0
- mindspore/ops/_utils/utils.py +147 -0
- mindspore/ops/_vmap/__init__.py +25 -0
- mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
- mindspore/ops/_vmap/vmap_base.py +533 -0
- mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
- mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
- mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
- mindspore/ops/_vmap/vmap_image_ops.py +194 -0
- mindspore/ops/_vmap/vmap_math_ops.py +993 -0
- mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
- mindspore/ops/_vmap/vmap_other_ops.py +105 -0
- mindspore/ops/_vmap/vmap_random_ops.py +122 -0
- mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
- mindspore/ops/auto_generate/__init__.py +31 -0
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
- mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
- mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
- mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
- mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
- mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
- mindspore/ops/composite/__init__.py +71 -0
- mindspore/ops/composite/base.py +1318 -0
- mindspore/ops/composite/env_ops.py +41 -0
- mindspore/ops/composite/math_ops.py +125 -0
- mindspore/ops/composite/multitype_ops/__init__.py +77 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
- mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
- mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
- mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
- mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
- mindspore/ops/deprecated.py +315 -0
- mindspore/ops/function/__init__.py +782 -0
- mindspore/ops/function/array_func.py +7226 -0
- mindspore/ops/function/clip_func.py +384 -0
- mindspore/ops/function/debug_func.py +181 -0
- mindspore/ops/function/fft_func.py +44 -0
- mindspore/ops/function/grad/__init__.py +34 -0
- mindspore/ops/function/grad/grad_func.py +1425 -0
- mindspore/ops/function/image_func.py +292 -0
- mindspore/ops/function/linalg_func.py +416 -0
- mindspore/ops/function/math_func.py +12228 -0
- mindspore/ops/function/nn_func.py +8609 -0
- mindspore/ops/function/other_func.py +115 -0
- mindspore/ops/function/parameter_func.py +134 -0
- mindspore/ops/function/random_func.py +1715 -0
- mindspore/ops/function/reshard_func.py +104 -0
- mindspore/ops/function/sparse_func.py +884 -0
- mindspore/ops/function/sparse_unary_func.py +2422 -0
- mindspore/ops/function/spectral_func.py +150 -0
- mindspore/ops/function/vmap_func.py +117 -0
- mindspore/ops/functional.py +464 -0
- mindspore/ops/op_info_register.py +1572 -0
- mindspore/ops/operations/__init__.py +722 -0
- mindspore/ops/operations/_csr_ops.py +403 -0
- mindspore/ops/operations/_custom_grad.py +181 -0
- mindspore/ops/operations/_embedding_cache_ops.py +307 -0
- mindspore/ops/operations/_grad_ops.py +2978 -0
- mindspore/ops/operations/_infer_ops.py +19 -0
- mindspore/ops/operations/_inner_ops.py +2544 -0
- mindspore/ops/operations/_map_tensor_ops.py +112 -0
- mindspore/ops/operations/_ms_kernel.py +601 -0
- mindspore/ops/operations/_ocr_ops.py +379 -0
- mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
- mindspore/ops/operations/_pyfunc_registry.py +58 -0
- mindspore/ops/operations/_quant_ops.py +1844 -0
- mindspore/ops/operations/_rl_inner_ops.py +1231 -0
- mindspore/ops/operations/_scalar_ops.py +106 -0
- mindspore/ops/operations/_sequence_ops.py +1155 -0
- mindspore/ops/operations/_sparse_grad_ops.py +56 -0
- mindspore/ops/operations/_tensor_array.py +359 -0
- mindspore/ops/operations/_thor_ops.py +807 -0
- mindspore/ops/operations/array_ops.py +6124 -0
- mindspore/ops/operations/comm_ops.py +1985 -0
- mindspore/ops/operations/control_ops.py +127 -0
- mindspore/ops/operations/custom_ops.py +1129 -0
- mindspore/ops/operations/debug_ops.py +678 -0
- mindspore/ops/operations/image_ops.py +1041 -0
- mindspore/ops/operations/inner_ops.py +697 -0
- mindspore/ops/operations/linalg_ops.py +95 -0
- mindspore/ops/operations/manually_defined/__init__.py +24 -0
- mindspore/ops/operations/manually_defined/_inner.py +73 -0
- mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
- mindspore/ops/operations/math_ops.py +5095 -0
- mindspore/ops/operations/nn_ops.py +9575 -0
- mindspore/ops/operations/other_ops.py +874 -0
- mindspore/ops/operations/random_ops.py +1288 -0
- mindspore/ops/operations/reshard_ops.py +53 -0
- mindspore/ops/operations/rl_ops.py +288 -0
- mindspore/ops/operations/sparse_ops.py +2753 -0
- mindspore/ops/operations/spectral_ops.py +111 -0
- mindspore/ops/primitive.py +1046 -0
- mindspore/ops/signature.py +54 -0
- mindspore/ops/vm_impl_registry.py +91 -0
- mindspore/ops_generate/__init__.py +27 -0
- mindspore/ops_generate/arg_dtype_cast.py +252 -0
- mindspore/ops_generate/arg_handler.py +197 -0
- mindspore/ops_generate/gen_aclnn_implement.py +263 -0
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +1099 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
- mindspore/ops_generate/gen_pyboost_func.py +1052 -0
- mindspore/ops_generate/gen_utils.py +209 -0
- mindspore/ops_generate/op_proto.py +145 -0
- mindspore/ops_generate/pyboost_utils.py +367 -0
- mindspore/ops_generate/template.py +261 -0
- mindspore/parallel/__init__.py +30 -0
- mindspore/parallel/_auto_parallel_context.py +1486 -0
- mindspore/parallel/_cell_wrapper.py +174 -0
- mindspore/parallel/_cost_model_context.py +700 -0
- mindspore/parallel/_dp_allreduce_fusion.py +159 -0
- mindspore/parallel/_offload_context.py +275 -0
- mindspore/parallel/_parallel_serialization.py +561 -0
- mindspore/parallel/_ps_context.py +242 -0
- mindspore/parallel/_recovery_context.py +110 -0
- mindspore/parallel/_tensor.py +730 -0
- mindspore/parallel/_transformer/__init__.py +35 -0
- mindspore/parallel/_transformer/layers.py +765 -0
- mindspore/parallel/_transformer/loss.py +251 -0
- mindspore/parallel/_transformer/moe.py +693 -0
- mindspore/parallel/_transformer/op_parallel_config.py +222 -0
- mindspore/parallel/_transformer/transformer.py +3119 -0
- mindspore/parallel/_utils.py +612 -0
- mindspore/parallel/algo_parameter_config.py +400 -0
- mindspore/parallel/checkpoint_transform.py +650 -0
- mindspore/parallel/cluster/__init__.py +15 -0
- mindspore/parallel/cluster/process_entity/__init__.py +18 -0
- mindspore/parallel/cluster/process_entity/_api.py +352 -0
- mindspore/parallel/cluster/process_entity/_utils.py +101 -0
- mindspore/parallel/cluster/run.py +136 -0
- mindspore/parallel/mpi/__init__.py +14 -0
- mindspore/parallel/mpi/_mpi_config.py +116 -0
- mindspore/parallel/parameter_broadcast.py +151 -0
- mindspore/parallel/shard.py +481 -0
- mindspore/parallel/transform_safetensors.py +993 -0
- mindspore/profiler/__init__.py +28 -0
- mindspore/profiler/common/__init__.py +14 -0
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/exceptions/__init__.py +14 -0
- mindspore/profiler/common/exceptions/error_code.py +83 -0
- mindspore/profiler/common/exceptions/exceptions.py +286 -0
- mindspore/profiler/common/process_pool.py +41 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/singleton.py +28 -0
- mindspore/profiler/common/struct_type.py +118 -0
- mindspore/profiler/common/util.py +472 -0
- mindspore/profiler/common/validator/__init__.py +14 -0
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +254 -0
- mindspore/profiler/parser/__init__.py +14 -0
- mindspore/profiler/parser/aicpu_data_parser.py +272 -0
- mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
- mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
- mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
- mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
- mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
- mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
- mindspore/profiler/parser/ascend_flops_generator.py +116 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
- mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
- mindspore/profiler/parser/ascend_memory_generator.py +185 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
- mindspore/profiler/parser/ascend_op_generator.py +334 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
- mindspore/profiler/parser/base_timeline_generator.py +483 -0
- mindspore/profiler/parser/container.py +229 -0
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
- mindspore/profiler/parser/flops_parser.py +531 -0
- mindspore/profiler/parser/framework_enum.py +111 -0
- mindspore/profiler/parser/framework_parser.py +464 -0
- mindspore/profiler/parser/framework_struct.py +61 -0
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/hccl_parser.py +573 -0
- mindspore/profiler/parser/hwts_log_parser.py +122 -0
- mindspore/profiler/parser/integrator.py +526 -0
- mindspore/profiler/parser/memory_usage_parser.py +277 -0
- mindspore/profiler/parser/minddata_analyzer.py +800 -0
- mindspore/profiler/parser/minddata_parser.py +186 -0
- mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
- mindspore/profiler/parser/op_intermediate_parser.py +149 -0
- mindspore/profiler/parser/optime_parser.py +250 -0
- mindspore/profiler/parser/profiler_info.py +213 -0
- mindspore/profiler/parser/step_trace_parser.py +666 -0
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +1922 -0
- mindspore/rewrite/__init__.py +28 -0
- mindspore/rewrite/api/__init__.py +17 -0
- mindspore/rewrite/api/node.py +519 -0
- mindspore/rewrite/api/node_type.py +53 -0
- mindspore/rewrite/api/pattern_engine.py +490 -0
- mindspore/rewrite/api/scoped_value.py +181 -0
- mindspore/rewrite/api/symbol_tree.py +497 -0
- mindspore/rewrite/ast_helpers/__init__.py +25 -0
- mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
- mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
- mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
- mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
- mindspore/rewrite/common/__init__.py +19 -0
- mindspore/rewrite/common/config.py +24 -0
- mindspore/rewrite/common/error_log.py +39 -0
- mindspore/rewrite/common/event.py +28 -0
- mindspore/rewrite/common/namer.py +271 -0
- mindspore/rewrite/common/namespace.py +118 -0
- mindspore/rewrite/common/observable.py +44 -0
- mindspore/rewrite/common/observer.py +54 -0
- mindspore/rewrite/node/__init__.py +22 -0
- mindspore/rewrite/node/call_function.py +95 -0
- mindspore/rewrite/node/cell_container.py +139 -0
- mindspore/rewrite/node/control_flow.py +113 -0
- mindspore/rewrite/node/node.py +1428 -0
- mindspore/rewrite/node/node_manager.py +283 -0
- mindspore/rewrite/node/node_topological_manager.py +223 -0
- mindspore/rewrite/parsers/__init__.py +29 -0
- mindspore/rewrite/parsers/arguments_parser.py +63 -0
- mindspore/rewrite/parsers/assign_parser.py +852 -0
- mindspore/rewrite/parsers/attribute_parser.py +57 -0
- mindspore/rewrite/parsers/class_def_parser.py +289 -0
- mindspore/rewrite/parsers/constant_parser.py +104 -0
- mindspore/rewrite/parsers/container_parser.py +88 -0
- mindspore/rewrite/parsers/expr_parser.py +55 -0
- mindspore/rewrite/parsers/for_parser.py +61 -0
- mindspore/rewrite/parsers/function_def_parser.py +84 -0
- mindspore/rewrite/parsers/if_parser.py +85 -0
- mindspore/rewrite/parsers/module_parser.py +117 -0
- mindspore/rewrite/parsers/parser.py +43 -0
- mindspore/rewrite/parsers/parser_register.py +86 -0
- mindspore/rewrite/parsers/return_parser.py +37 -0
- mindspore/rewrite/parsers/while_parser.py +59 -0
- mindspore/rewrite/sparsify/__init__.py +0 -0
- mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
- mindspore/rewrite/sparsify/sparsify.py +112 -0
- mindspore/rewrite/sparsify/utils.py +179 -0
- mindspore/rewrite/symbol_tree/__init__.py +20 -0
- mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
- mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
- mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
- mindspore/run_check/__init__.py +20 -0
- mindspore/run_check/_check_version.py +507 -0
- mindspore/run_check/run_check.py +66 -0
- mindspore/safeguard/__init__.py +18 -0
- mindspore/safeguard/rewrite_obfuscation.py +875 -0
- mindspore/scipy/__init__.py +18 -0
- mindspore/scipy/fft.py +264 -0
- mindspore/scipy/linalg.py +919 -0
- mindspore/scipy/ops.py +165 -0
- mindspore/scipy/ops_grad.py +115 -0
- mindspore/scipy/ops_wrapper.py +74 -0
- mindspore/scipy/optimize/__init__.py +20 -0
- mindspore/scipy/optimize/_bfgs.py +230 -0
- mindspore/scipy/optimize/_lagrange.py +201 -0
- mindspore/scipy/optimize/_lbfgs.py +146 -0
- mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
- mindspore/scipy/optimize/line_search.py +370 -0
- mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
- mindspore/scipy/optimize/minimize.py +200 -0
- mindspore/scipy/utils.py +156 -0
- mindspore/scipy/utils_const.py +246 -0
- mindspore/train/__init__.py +48 -0
- mindspore/train/_utils.py +465 -0
- mindspore/train/amp.py +935 -0
- mindspore/train/anf_ir_pb2.py +1517 -0
- mindspore/train/callback/__init__.py +44 -0
- mindspore/train/callback/_backup_and_restore.py +117 -0
- mindspore/train/callback/_callback.py +613 -0
- mindspore/train/callback/_checkpoint.py +814 -0
- mindspore/train/callback/_cluster_monitor.py +201 -0
- mindspore/train/callback/_dataset_graph.py +150 -0
- mindspore/train/callback/_early_stop.py +239 -0
- mindspore/train/callback/_flops_collector.py +239 -0
- mindspore/train/callback/_history.py +92 -0
- mindspore/train/callback/_lambda_callback.py +80 -0
- mindspore/train/callback/_landscape.py +1049 -0
- mindspore/train/callback/_loss_monitor.py +107 -0
- mindspore/train/callback/_lr_scheduler_callback.py +76 -0
- mindspore/train/callback/_on_request_exit.py +298 -0
- mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
- mindspore/train/callback/_summary_collector.py +1184 -0
- mindspore/train/callback/_tft_register.py +352 -0
- mindspore/train/callback/_time_monitor.py +141 -0
- mindspore/train/checkpoint_pb2.py +233 -0
- mindspore/train/data_sink.py +219 -0
- mindspore/train/dataset_helper.py +692 -0
- mindspore/train/lineage_pb2.py +1260 -0
- mindspore/train/loss_scale_manager.py +213 -0
- mindspore/train/memory_profiling_pb2.py +298 -0
- mindspore/train/metrics/__init__.py +175 -0
- mindspore/train/metrics/accuracy.py +133 -0
- mindspore/train/metrics/auc.py +129 -0
- mindspore/train/metrics/bleu_score.py +170 -0
- mindspore/train/metrics/confusion_matrix.py +700 -0
- mindspore/train/metrics/cosine_similarity.py +109 -0
- mindspore/train/metrics/dice.py +116 -0
- mindspore/train/metrics/error.py +175 -0
- mindspore/train/metrics/fbeta.py +167 -0
- mindspore/train/metrics/hausdorff_distance.py +333 -0
- mindspore/train/metrics/loss.py +97 -0
- mindspore/train/metrics/mean_surface_distance.py +189 -0
- mindspore/train/metrics/metric.py +373 -0
- mindspore/train/metrics/occlusion_sensitivity.py +225 -0
- mindspore/train/metrics/perplexity.py +133 -0
- mindspore/train/metrics/precision.py +160 -0
- mindspore/train/metrics/recall.py +159 -0
- mindspore/train/metrics/roc.py +223 -0
- mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
- mindspore/train/metrics/topk.py +167 -0
- mindspore/train/mind_ir_pb2.py +1908 -0
- mindspore/train/model.py +2252 -0
- mindspore/train/node_strategy_pb2.py +653 -0
- mindspore/train/print_pb2.py +184 -0
- mindspore/train/profiling_parallel_pb2.py +151 -0
- mindspore/train/serialization.py +3325 -0
- mindspore/train/summary/__init__.py +23 -0
- mindspore/train/summary/_lineage_adapter.py +41 -0
- mindspore/train/summary/_summary_adapter.py +496 -0
- mindspore/train/summary/_writer_pool.py +207 -0
- mindspore/train/summary/enums.py +56 -0
- mindspore/train/summary/summary_record.py +581 -0
- mindspore/train/summary/writer.py +167 -0
- mindspore/train/summary_pb2.py +1165 -0
- mindspore/train/train_thor/__init__.py +20 -0
- mindspore/train/train_thor/convert_utils.py +268 -0
- mindspore/train/train_thor/dataset_helper.py +192 -0
- mindspore/train/train_thor/model_thor.py +257 -0
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -0
- mindspore-2.4.0.dist-info/METADATA +352 -0
- mindspore-2.4.0.dist-info/RECORD +1387 -0
- mindspore-2.4.0.dist-info/WHEEL +5 -0
- mindspore-2.4.0.dist-info/entry_points.txt +3 -0
- mindspore-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,2895 @@
|
|
|
1
|
+
# Copyright 2019-2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
Built-in validators.
|
|
18
|
+
"""
|
|
19
|
+
import inspect as ins
|
|
20
|
+
import os
|
|
21
|
+
from functools import wraps
|
|
22
|
+
import numpy as np
|
|
23
|
+
|
|
24
|
+
from mindspore._c_expression import typing
|
|
25
|
+
from mindspore import log as logger
|
|
26
|
+
from ..core.validator_helpers import parse_user_args, type_check, type_check_list, check_value, \
|
|
27
|
+
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
|
28
|
+
validate_dataset_param_value, check_padding_options, \
|
|
29
|
+
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \
|
|
30
|
+
check_valid_list_tuple, check_int32, check_independent_mode
|
|
31
|
+
|
|
32
|
+
from . import datasets
|
|
33
|
+
from . import samplers
|
|
34
|
+
from . import cache_client
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def check_cmu_arctic_dataset(method):
    """Return *method* wrapped so that CMUArcticDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # Speaker id, checked against the supported ids only when supplied.
        speaker = params.get('name')
        if speaker is not None:
            speakers = ['aew', 'ahw', 'aup', 'awb', 'axb', 'bdl', 'clb', 'eey',
                        'fem', 'gka', 'jmk', 'ksp', 'ljm', 'lnh', 'rms', 'rxr', 'slp', 'slt']
            check_valid_str(speaker, speakers, "name")

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle'], bool),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def check_gtzan_dataset(method):
    """Return *method* wrapped so that GTZANDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ['train', 'valid', 'test', 'all'], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def check_imagefolderdataset(method):
    """Return *method* wrapped so that ImageFolderDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        # ``decrypt`` may be omitted (None); anything else must be callable.
        decrypt = params.get('decrypt')
        if not (decrypt is None or callable(decrypt)):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['class_indexing'], dict),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def check_imdb_dataset(method):
    """Return *method* wrapped so that IMDBDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle'], bool)):
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        # ``usage`` is deliberately checked last, after the cache option.
        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "all"], "usage")

        return method(self, *args, **kwargs)

    return new_method
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def check_iwslt2016_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(IWSLT2016Dataset).

    The wrapped method is called only after ``dataset_dir``, ``usage``, ``language_pair``,
    ``valid_set``, ``test_set``, the integer arguments, the sampler/shuffle/shard combination
    and ``cache`` have been validated.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        # check usage
        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        # Supported (source, target) pairs, each listed exactly once. A list-typed
        # ``language_pair`` is compared against list entries and a tuple-typed one
        # against tuple entries, so the tuple table is derived from the list table
        # to keep the two forms in sync.
        support_language_pair = [
            ['en', 'ar'], ['en', 'de'], ['en', 'fr'], ['en', 'cs'], ['ar', 'en'], ['fr', 'en'],
            ['de', 'en'], ['cs', 'en']
        ]
        support_language_pair_tuple = tuple(tuple(pair) for pair in support_language_pair)
        support_set_type = ["dev2010", "tst2010", "tst2011", "tst2012", "tst2013", "tst2014"]

        # check language_pair
        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            if isinstance(language_pair, list):
                check_valid_list_tuple(language_pair, support_language_pair, (str,), "language_pair")
            elif isinstance(language_pair, tuple):
                check_valid_list_tuple(language_pair, support_language_pair_tuple, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        # check valid_set
        valid_set = param_dict.get('valid_set')
        if valid_set is not None:
            check_valid_str(valid_set, support_set_type, "valid_set")

        # check test_set
        test_set = param_dict.get('test_set')
        if test_set is not None:
            check_valid_str(test_set, support_set_type, "test_set")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def check_iwslt2017_dataset(method):
    """Return *method* wrapped so that IWSLT2017Dataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "valid", "all"], "usage")

        pairs_as_lists = [
            ['en', 'nl'], ['en', 'de'], ['en', 'it'], ['en', 'ro'], ['ro', 'de'], ['ro', 'en'], ['ro', 'nl'],
            ['ro', 'it'], ['de', 'ro'], ['de', 'en'], ['de', 'nl'], ['de', 'it'], ['it', 'en'], ['it', 'nl'],
            ['it', 'de'], ['it', 'ro'], ['nl', 'de'], ['nl', 'en'], ['nl', 'it'], ['nl', 'ro']
        ]
        # Same pairs in tuple form, used when the caller passes a tuple.
        pairs_as_tuples = tuple(tuple(pair) for pair in pairs_as_lists)

        pair = params.get('language_pair')
        if pair is not None:
            if isinstance(pair, list):
                check_valid_list_tuple(pair, pairs_as_lists, (str,), "language_pair")
            elif isinstance(pair, tuple):
                check_valid_list_tuple(pair, pairs_as_tuples, (str,), "language_pair")
            else:
                raise TypeError("language_pair should be a type list or tuple of length 2.")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def check_kittidataset(method):
    """Return *method* wrapped so that KITTIDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test"], "usage")

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle', 'decode'], bool)):
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def check_lsun_dataset(method):
    """Return *method* wrapped so that LSUNDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test", "valid", "all"], "usage")

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['classes'], list),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        scene_categories = [
            'bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 'dining_room', 'kitchen',
            'living_room', 'restaurant', 'tower'
        ]
        requested = params.get('classes')
        if requested is not None:
            # Each requested class name is checked individually against the known categories.
            for category in requested:
                check_valid_str(category, scene_categories, "classes")

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def check_mnist_cifar_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(MnistDataset, Cifar10/100Dataset).

    Before the wrapped method is called, the following are validated:
    ``dataset_dir``, ``usage`` (when given), the integer arguments, ``shuffle``,
    the sampler/shuffle/shard combination and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def check_omniglotdataset(method):
    """Return *method* wrapped so that OmniglotDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle', 'background', 'decode'], bool)):
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def check_photo_tour_dataset(method):
    """Return *method* wrapped so that PhotoTourDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "test"], "usage")

        # Unlike ``usage``, ``name`` is checked without a None guard.
        scene_names = ["notredame", "yosemite", "liberty",
                       "notredame_harris", "yosemite_harris", "liberty_harris"]
        check_valid_str(params.get('name'), scene_names, "name")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))
        return method(self, *args, **kwargs)

    return new_method
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def check_places365_dataset(method):
    """Return *method* wrapped so that Places365Dataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train-standard", "train-challenge", "val"], "usage")

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle', 'small', 'decode'], bool)):
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
def check_qmnist_dataset(method):
    """Return *method* wrapped so that QMnistDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            allowed_splits = ["train", "test", "test10k", "test50k", "nist", "all"]
            check_valid_str(split, allowed_splits, "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle', 'compat'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
def check_manifestdataset(method):
    """Return *method* wrapped so that ManifestDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # The manifest is a single file rather than a directory.
        check_file(params.get('dataset_file'))

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['usage'], str),
            (['class_indexing'], dict),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def check_sbu_dataset(method):
    """Return *method* wrapped so that SBUDataset arguments are checked before it runs."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        # The dataset directory is expected to hold these fixed entries.
        for text_file in ("SBU_captioned_photo_dataset_urls.txt", "SBU_captioned_photo_dataset_captions.txt"):
            check_file(os.path.join(root, text_file))
        check_dir(os.path.join(root, "sbu_images"))

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle', 'decode'], bool)):
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def check_sogou_news_dataset(method):
    """Decorator that validates SogouNewsDataset arguments before the constructor runs.

    Checks ``dataset_dir`` with check_dir, restricts ``usage`` (when given) to
    "train", "test" or "all", type-validates the integer sampling/sharding
    arguments, then checks the sampler/shuffle/shard combination and ``cache``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        int_args = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_args, params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def check_tfrecorddataset(method):
    """Decorator that validates TFRecordDataset arguments before the constructor runs.

    ``dataset_files`` must be a non-empty str or list. Integer, list and bool
    arguments are type-validated. ``compression_type`` must be None, '', 'ZLIB' or
    'GZIP'. When a real compression ('ZLIB'/'GZIP') is combined with
    ``num_samples``:

    * a ValueError is raised if ``num_shards`` is given and exceeds the number of
      files (a plain string counts as one file);
    * a warning is logged unless ``shard_equal_rows`` is truthy.

    Finally the sampler/shuffle/shard combination and ``cache`` are checked.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        files = params.get('dataset_files')
        if not isinstance(files, (str, list)):
            raise TypeError("dataset_files should be type str or a list of strings.")
        if not files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(files) + "'.")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['columns_list'], params, list)
        validate_dataset_param_value(['shard_equal_rows'], params, bool)

        compression = params.get('compression_type')
        if compression is not None and compression not in ['', 'ZLIB', 'GZIP']:
            raise ValueError("Input compression_type can only be either '' (no compression), 'ZLIB', or 'GZIP', " +
                             "but got '" + str(compression) + "'.")

        # None and '' are both excluded by this membership test.
        if compression in ['ZLIB', 'GZIP'] and params.get('num_samples') is not None:
            shards = params.get('num_shards')
            if shards is not None:
                # files is either a str (one file/pattern) or a list at this point.
                file_count = len(files) if isinstance(files, list) else 1
                if file_count < shards:
                    raise ValueError("When compression_type is provided, the number of dataset files cannot be less " +
                                     "than num_shards, but the actual number of files is " + str(file_count) +
                                     " and actual num_shards is " + str(shards) + ".")
            if not params.get('shard_equal_rows'):
                logger.warning("If compression_type is set, shard_equal_rows will be ignored.")

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
def check_udpos_dataset(method):
    """Decorator that validates UDPOSDataset arguments before the constructor runs.

    Validation order: the required ``dataset_dir`` via check_dir; the optional
    ``usage`` against "train"/"valid"/"test"/"all"; the integer
    sampling/sharding arguments; the sampler/shuffle/shard combination; ``cache``.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_dir is a required argument.
        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def check_usps_dataset(method):
    """Decorator that validates USPSDataset arguments before the constructor runs.

    ``dataset_dir`` goes through check_dir and ``usage`` (if not None) must be
    "train", "test" or "all". The integer sampling/sharding arguments are then
    type-validated, followed by the sampler/shuffle/shard and ``cache`` checks.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        requested_usage = params.get('usage')
        if requested_usage is not None:
            check_valid_str(requested_usage, ["train", "test", "all"], "usage")

        int_args = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_args, params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))
        return method(self, *args, **kwargs)

    return wrapper
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def check_caltech101_dataset(method):
    """Decorator that validates Caltech101Dataset arguments before construction.

    Runs, in order:
    1. check_dir on ``dataset_dir``;
    2. check_valid_str on ``target_type`` (when given) against
       "category"/"annotation"/"all";
    3. int/bool/str type validation of the optional arguments;
    4. the sampler/shuffle/shard combination check;
    5. the cache check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        target_type = params.get('target_type')
        if target_type is not None:
            check_valid_str(target_type, ["category", "annotation", "all"], "target_type")

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['target_type'], str),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
def check_caltech256_dataset(method):
    """Decorator that validates Caltech256Dataset arguments before construction.

    Checks ``dataset_dir`` with check_dir, type-validates the integer
    sampling/sharding arguments and the ``shuffle``/``decode`` flags, then runs
    the sampler/shuffle/shard combination check and the cache check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle', 'decode'], bool)):
            validate_dataset_param_value(names, params, expected_type)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
724
|
+
|
|
725
|
+
|
|
726
|
+
def check_vocdataset(method):
    """Decorator that validates VOCDataset arguments before construction.

    ``task`` and ``usage`` are type-checked as str. ``task`` selects the ImageSets
    sub-folder: "Segmentation" maps to ``Segmentation`` and "Detection" maps to
    ``Main``. Any other task raises ValueError. ``class_indexing`` is rejected for
    the Segmentation task. ``decrypt``, when given, must be callable. The file
    ``<realpath(dataset_dir)>/ImageSets/<sub-folder>/<usage>.txt`` is checked with
    check_file before the int/bool/dict arguments, the sampler options and
    ``cache`` are validated.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        task = params.get('task')
        type_check(task, (str,), "task")

        usage = params.get('usage')
        type_check(usage, (str,), "usage")
        root = os.path.realpath(root)

        subfolder = {"Segmentation": "Segmentation", "Detection": "Main"}.get(task)
        if subfolder is None:
            raise ValueError("Invalid task : " + task + ".")
        imagesets_file = os.path.join(root, "ImageSets", subfolder, usage + ".txt")
        if task == "Segmentation" and params.get('class_indexing') is not None:
            raise ValueError("class_indexing is not supported in Segmentation task.")

        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        check_file(imagesets_file)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        validate_dataset_param_value(['class_indexing'], params, dict)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
773
|
+
|
|
774
|
+
|
|
775
|
+
def check_cocodataset(method):
    """Decorator that validates CocoDataset arguments before construction.

    Checks ``dataset_dir`` (check_dir) and ``annotation_file`` (check_file).
    ``task`` must be a str from {Detection, Stuff, Panoptic, Keypoint, Captioning}
    and ``decrypt``, if given, must be callable. Int/bool arguments are
    type-validated and a PKSampler is rejected. The sampler/shuffle/shard
    combination and ``cache`` are checked last.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))
        check_file(params.get('annotation_file'))

        task = params.get('task')
        type_check(task, (str,), "task")
        if task not in {'Detection', 'Stuff', 'Panoptic', 'Keypoint', 'Captioning'}:
            raise ValueError("Invalid task type: " + task + ".")

        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        # isinstance(None, ...) is False, so no separate None test is needed.
        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CocoDataset doesn't support PKSampler.")
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
def check_celebadataset(method):
    """Decorator that validates CelebADataset arguments before construction.

    After check_dir on ``dataset_dir`` and a callable check on ``decrypt``, the
    int/bool/list/str arguments are type-validated. ``usage`` (when given) must be
    'all', 'train', 'valid' or 'test'. The sampler/shuffle/shard check then runs,
    a PKSampler is rejected, and ``cache`` is validated.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        decrypt = params.get('decrypt')
        if decrypt is not None and not callable(decrypt):
            raise TypeError("Argument decrypt is not a callable object, but got " + str(type(decrypt)))

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
            (['shuffle', 'decode'], bool),
            (['extensions'], list),
            (['dataset_type'], str),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        usage = params.get('usage')
        if usage is not None and usage not in ('all', 'train', 'valid', 'test'):
            raise ValueError("usage should be 'all', 'train', 'valid' or 'test'.")

        check_sampler_shuffle_shard_options(params)

        if isinstance(params.get('sampler'), samplers.PKSampler):
            raise ValueError("CelebADataset doesn't support PKSampler.")

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
def check_libri_tts_dataset(method):
    """Decorator that validates LibriTTSDataset arguments before construction.

    ``dataset_dir`` goes through check_dir. ``usage`` (when given) must name one
    of the LibriTTS subsets or "all". The int arguments and ``shuffle`` flag are
    type-validated, followed by the sampler/shuffle/shard and ``cache`` checks.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        subset = params.get('usage')
        if subset is not None:
            known_subsets = ["dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100",
                             "train-clean-360", "train-other-500", "all"]
            check_valid_str(subset, known_subsets, "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
def check_lj_speech_dataset(method):
    """Decorator that validates LJSpeechDataset arguments before construction.

    Checks ``dataset_dir`` with check_dir, type-validates the integer
    sampling/sharding arguments and ``shuffle``, then runs the
    sampler/shuffle/shard combination check and the cache check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        for names, expected_type in ((['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], int),
                                     (['shuffle'], bool)):
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
914
|
+
|
|
915
|
+
|
|
916
|
+
def check_lfw_dataset(method):
    """Decorator that validates LFWDataset arguments before construction.

    Checks ``dataset_dir`` with check_dir. The optional ``task``, ``usage`` and
    ``image_set`` arguments are each checked with check_valid_str against their
    allowed values, in that order. Int/bool arguments are then type-validated,
    followed by the sampler/shuffle/shard check and the cache check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        enum_args = (
            ("task", ["people", "pairs"]),
            ("usage", ["10fold", "train", "test", "all"]),
            ("image_set", ["original", "funneled", "deepfunneled"]),
        )
        for arg_name, allowed in enum_args:
            value = params.get(arg_name)
            if value is not None:
                check_valid_str(value, allowed, arg_name)

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'], params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
def check_save(method):
    """Decorator that validates the arguments of a dataset save method.

    ``num_files`` is type-validated as int and must lie in 1..1000 inclusive.
    ``file_name`` and ``file_type`` are type-validated as str, and ``file_type``
    must equal 'mindrecord'. Violations raise ValueError before the wrapped
    method runs.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        validate_dataset_param_value(['num_files'], params, int)
        num_files = params.get('num_files')
        if num_files <= 0 or num_files > 1000:
            raise ValueError("num_files should between 0 and 1000.")

        validate_dataset_param_value(['file_name', 'file_type'], params, str)
        file_type = params.get('file_type')
        if file_type != 'mindrecord':
            raise ValueError("{} dataset format is not supported.".format(file_type))

        return method(self, *args, **kwargs)

    return wrapper
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def check_tuple_iterator(method):
    """Decorator that validates the arguments of a tuple-iterator factory method.

    parse_user_args must yield exactly four positional values. The first two are
    ``columns`` and ``num_epochs``; the other two are not inspected here.
    ``output_numpy`` is type-validated as bool. ``num_epochs``, when given, must
    be an int and is range-checked against [-1, INT32_MAX] via check_value.
    ``columns``, when given, is checked with check_columns.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        positional, params = parse_user_args(method, *args, **kwargs)
        columns, num_epochs, _, _ = positional

        validate_dataset_param_value(['output_numpy'], params, bool)

        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        if columns is not None:
            check_columns(columns, "column_names")

        return method(self, *args, **kwargs)

    return wrapper
|
|
991
|
+
|
|
992
|
+
|
|
993
|
+
def check_dict_iterator(method):
    """Decorator that validates the arguments of a dict-iterator factory method.

    parse_user_args must yield exactly three positional values. Only the first,
    ``num_epochs``, is inspected: when given it must be an int and is
    range-checked against [-1, INT32_MAX] via check_value. ``output_numpy`` is
    type-validated as bool.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        positional, params = parse_user_args(method, *args, **kwargs)
        num_epochs, _, _ = positional

        validate_dataset_param_value(['output_numpy'], params, bool)

        if num_epochs is not None:
            type_check(num_epochs, (int,), "num_epochs")
            check_value(num_epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return wrapper
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
def check_minddataset(method):
    """Decorator that validates MindDataset arguments before construction.

    ``dataset_files`` may be a list of paths (each checked with check_file; a
    warning is logged when there are more than 4096 of them) or a single value
    checked with check_file. Integer, list and dict arguments are then
    type-validated, followed by the sampler/shuffle/shard, padding and cache
    checks.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'seed', 'num_shards', 'shard_id', 'num_padded']
        nreq_param_list = ['columns_list']
        nreq_param_dict = ['padded_sample']

        dataset_file = param_dict.get('dataset_files')
        if isinstance(dataset_file, list):
            if len(dataset_file) > 4096:
                # Adjacent literals are concatenated verbatim, so the separating
                # space must live inside the first literal.
                logger.warning("The number of MindRecord files is greater than 4096, "
                               "which may cause slow dataset initialization.")
            for f in dataset_file:
                check_file(f)
        else:
            check_file(dataset_file)

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_dict, param_dict, dict)

        check_sampler_shuffle_shard_options(param_dict)

        check_padding_options(param_dict)
        check_cache_option(param_dict.get('cache'))
        return method(self, *args, **kwargs)

    return new_method
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
def check_source_function(source):
    """Return closure-variable info plus source text for a GeneratorDataset source.

    Plain Python functions (lambdas included) are inspected directly with
    ``inspect.getclosurevars`` and ``inspect.getsource``. If either raises
    OSError, an empty string is returned.

    For any other object, only methods defined in the object's own class
    ``__dict__`` are inspected: ``__init__`` (if present), then ``__getitem__``
    or, only when that is absent, ``__next__``. A TypeError or OSError while
    inspecting stops further inspection, but whatever was gathered so far is kept.

    Returns:
        str: ``str()`` of the gathered closure-variable data followed by the
        concatenated source text.
    """
    from types import FunctionType

    if isinstance(source, FunctionType):
        try:
            closure = ins.getclosurevars(source)
            text = ins.getsource(source)
        except OSError:
            return ""
        return str(closure) + text

    closure = tuple()
    text = ""
    try:
        own_members = source.__class__.__dict__.keys()
        targets = []
        if '__init__' in own_members:
            targets.append(source.__class__.__init__)
        for accessor in ('__getitem__', '__next__'):
            if accessor in own_members:
                targets.append(getattr(source.__class__, accessor))
                break
        for func in targets:
            closure = closure + ins.getclosurevars(func)
            text = text + ins.getsource(func)
    except (TypeError, OSError):
        # getclosurevars raises TypeError for non-Python callables (e.g. slot
        # wrappers of built-in types such as generators); getsource raises
        # OSError when no source is available. Keep the partial result.
        pass
    return str(closure) + text
|
|
1072
|
+
|
|
1073
|
+
|
|
1074
|
+
def check_generatordataset(method):
    """Decorator that validates GeneratorDataset arguments before construction.

    Checks, in order:

    * ``source`` must be callable or iterable; otherwise a TypeError is raised.
    * The text returned by check_source_function(source) is scanned for
      'mindspore.nn', 'mindspore.ops', 'mindspore.numpy' or
      'mindspore.compression'. On a match, ``self.operator_mixed`` is set to True.
    * ``column_names`` (checked with check_columns) and ``schema`` cannot both
      be None. A non-None ``schema`` must be a datasets.Schema or a str.
    * Int/list/bool arguments are type-validated, ``max_rowsize`` is
      range-checked, and ``num_shards``/``shard_id`` are checked together.
    * ``sampler`` must not be a PKSampler; a non-builtin sampler must be iterable.
    * Both ``sampler`` and ``num_shards`` require a ``source`` with a
      ``__getitem__`` attribute.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        source = params.get('source')
        if not callable(source):
            try:
                iter(source)
            except TypeError:
                raise TypeError("Input `source` function of GeneratorDataset should be callable, iterable or random"
                                " accessible, commonly it should implement one of the method like yield, __getitem__ or"
                                " __next__(__iter__).")

        # Flag sources whose variables/source text reference MindSpore computing modules.
        source_text = check_source_function(source)
        compute_modules = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression']
        if any(module in source_text for module in compute_modules):
            setattr(self, 'operator_mixed', True)

        column_names = params.get('column_names')
        if column_names is not None:
            check_columns(column_names, "column_names")
        schema = params.get('schema')
        if column_names is None and schema is None:
            raise ValueError("Neither columns_names nor schema are provided.")
        if schema is not None and not isinstance(schema, (datasets.Schema, str)):
            raise ValueError("schema should be a path to schema file or a schema object.")

        # Optional arguments.
        typed_groups = (
            (["max_rowsize", "num_samples", "num_parallel_workers", "num_shards", "shard_id"], int),
            (["column_types"], list),
            (["shuffle", "python_multiprocessing"], bool),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        max_rowsize = params.get("max_rowsize")
        if max_rowsize is not None:
            check_value(max_rowsize, [-1, INT32_MAX], "max_rowsize")

        num_shards = params.get("num_shards")
        check_dataset_num_shards_shard_id(num_shards, params.get("shard_id"))

        sampler = params.get("sampler")
        if sampler is not None:
            if isinstance(sampler, samplers.PKSampler):
                raise ValueError("GeneratorDataset doesn't support PKSampler.")
            if not isinstance(sampler, samplers.BuiltinSampler):
                try:
                    iter(sampler)
                except TypeError:
                    raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers.")

        if sampler is not None and not hasattr(source, "__getitem__"):
            raise ValueError("sampler is not supported if source does not have attribute '__getitem__'.")
        if num_shards is not None and not hasattr(source, "__getitem__"):
            raise ValueError("num_shards is not supported if source does not have attribute '__getitem__'.")

        return method(self, *args, **kwargs)

    return wrapper
|
|
1144
|
+
|
|
1145
|
+
|
|
1146
|
+
def check_random_dataset(method):
    """Decorator that validates RandomDataset arguments before construction.

    Type-validates the integer arguments (including ``total_rows``), the
    ``shuffle`` flag and the ``columns_list`` list. It then runs the
    sampler/shuffle/shard combination check and the cache check.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        typed_groups = (
            (['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id', 'total_rows'], int),
            (['shuffle'], bool),
            (['columns_list'], list),
        )
        for names, expected_type in typed_groups:
            validate_dataset_param_value(names, params, expected_type)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return wrapper
|
|
1169
|
+
|
|
1170
|
+
|
|
1171
|
+
def check_rendered_sst2_dataset(method):
    """Decorator that validates RenderedSST2Dataset arguments before construction.

    Checks ``dataset_dir`` with check_dir and ``usage`` (when given) against
    'val', 'all', 'train' and 'test'. Int/bool arguments are then type-validated,
    followed by the sampler/shuffle/shard combination check and the cache check.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        usage = param_dict.get('usage')
        check_dir(dataset_dir)
        if usage is not None:
            # Pass the argument name, as every other checker in this module does,
            # so a rejected value is reported against "usage".
            check_valid_str(usage, ['val', 'all', 'train', 'test'], "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
1197
|
+
|
|
1198
|
+
|
|
1199
|
+
def check_pad_info(key, val):
    """Check one ``key: val`` entry of the ``pad_info`` argument of batch.

    Args:
        key: Column name; type-checked as str.
        val: None, or a 2-tuple ``(shape, pad_value)``. ``shape`` is None or a
            list whose non-None entries are checked with check_pos_int32.
            ``pad_value`` is None or one of int/float/str/bytes.

    Raises:
        ValueError: if ``val`` is a tuple whose length is not 2.
        Whatever type_check / check_pos_int32 raise for an ill-typed key, value,
        shape, dimension or pad value.
    """
    type_check(key, (str,), "key in pad_info")

    if val is None:
        return

    # Check the type before calling len(): a non-sized value such as an int would
    # otherwise escape as a bare "object has no len()" TypeError instead of the
    # descriptive type_check error.
    type_check(val, (tuple,), "value in pad_info")
    if len(val) != 2:
        raise ValueError("value of pad_info should be a tuple of size 2.")

    shape, pad_value = val
    if shape is not None:
        type_check(shape, (list,), "shape in pad_info")

        for dim in shape:
            if dim is not None:
                check_pos_int32(dim, "dim of shape in pad_info")
    if pad_value is not None:
        type_check(pad_value, (int, float, str, bytes), "pad_value")
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
def check_bucket_batch_by_length(method):
    """
    Wrap bucket_batch_by_length with validation of its arguments.

    Before the wrapped method runs:
      * column_names / bucket_boundaries / bucket_batch_sizes go through type_check_list
        with ``list``; pad_to_bucket_boundary / drop_remainder with ``bool``.
      * column_names passes check_columns. Without element_length_function exactly one
        column name is required; otherwise element_length_function must be callable.
      * bucket_boundaries must be a non-empty, strictly increasing list of positive ints.
      * bucket_batch_sizes must hold positive ints, one more than bucket_boundaries.
      * each pad_info entry is checked with check_pad_info.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (column_names, bucket_boundaries, bucket_batch_sizes, element_length_function, pad_info,
         pad_to_bucket_boundary, drop_remainder) = parsed

        type_check_list([column_names, bucket_boundaries, bucket_batch_sizes], (list,),
                        ['column_names', 'bucket_boundaries', 'bucket_batch_sizes'])
        type_check_list([pad_to_bucket_boundary, drop_remainder], (bool,),
                        ['pad_to_bucket_boundary', 'drop_remainder'])

        # column_names: must be a list of strings.
        check_columns(column_names, "column_names")

        if element_length_function is None:
            if len(column_names) != 1:
                raise ValueError(
                    "If element_length_function is not specified, exactly one column name should be passed.")
        elif not callable(element_length_function):
            raise TypeError("element_length_function object is not callable.")

        # bucket_boundaries: non-empty list of positive ints, strictly increasing.
        if not bucket_boundaries:
            raise ValueError("bucket_boundaries cannot be empty.")
        if not all(isinstance(boundary, int) for boundary in bucket_boundaries):
            raise TypeError("bucket_boundaries should be a list of int.")
        if any(boundary <= 0 for boundary in bucket_boundaries):
            raise ValueError("bucket_boundaries must only contain positive numbers.")
        for lower, upper in zip(bucket_boundaries, bucket_boundaries[1:]):
            if upper <= lower:
                raise ValueError("bucket_boundaries should be strictly increasing.")

        # bucket_batch_sizes: one positive int per bucket (len(boundaries) + 1 buckets).
        if len(bucket_batch_sizes) != len(bucket_boundaries) + 1:
            raise ValueError("bucket_batch_sizes must contain one element more than bucket_boundaries.")
        if not all(isinstance(size, int) for size in bucket_batch_sizes):
            raise TypeError("bucket_batch_sizes should be a list of int.")
        if any(size <= 0 for size in bucket_batch_sizes):
            raise ValueError("bucket_batch_sizes should be a list of positive numbers.")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")
            for column, pad_spec in pad_info.items():
                check_pad_info(column, pad_spec)

        return method(self, *args, **kwargs)

    return new_method
|
|
1279
|
+
|
|
1280
|
+
|
|
1281
|
+
def get_batch_kwargs_from_dict(param_dict):
    """
    Extract the optional keyword arguments of the batch operation.

    Args:
        param_dict (dict): Keyword arguments given to batch. ``None`` is treated as an
            empty dict so every option falls back to its default.

    Returns:
        tuple: ``(per_batch_map, input_columns, output_columns, python_multiprocessing,
        max_rowsize)``; missing keys default to ``None, None, None, False, None``.
    """
    # Without this, a None param_dict would leave every local unbound.
    if param_dict is None:
        param_dict = {}
    per_batch_map = param_dict.get("per_batch_map", None)
    input_columns = param_dict.get("input_columns", None)
    output_columns = param_dict.get("output_columns", None)
    python_multiprocessing = param_dict.get("python_multiprocessing", False)
    max_rowsize = param_dict.get("max_rowsize", None)
    return per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize
|
|
1290
|
+
|
|
1291
|
+
|
|
1292
|
+
def check_batch(method):
    """
    Wrap Dataset.batch with validation of its arguments.

    batch_size must be an int (checked with check_pos_int32) or a callable taking exactly
    one parameter (BatchInfo). num_parallel_workers, drop_remainder, max_rowsize,
    input_columns, output_columns and python_multiprocessing are also validated.
    input_columns is only allowed together with per_batch_map, and its column count must
    equal the number of per_batch_map parameters minus one.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [batch_size, drop_remainder, num_parallel_workers, param_dict], _ = parse_user_args(method, *args, **kwargs)

        (per_batch_map, input_columns, output_columns, python_multiprocessing, max_rowsize) = \
            get_batch_kwargs_from_dict(param_dict)

        if not (isinstance(batch_size, int) or (callable(batch_size))):
            raise TypeError("batch_size should either be an int or a callable.")

        if callable(batch_size):
            sig = ins.signature(batch_size)
            if len(sig.parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        else:
            check_pos_int32(int(batch_size), "batch_size")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        check_max_rowsize(max_rowsize)

        if (input_columns is not None) and (per_batch_map is None):
            # input_columns must be None when per_batch_map is not set
            raise ValueError("input_columns can be specified only when per_batch_map is set.")

        if input_columns is not None:
            check_columns(input_columns, "input_columns")
            # input_columns may be a single column name (str); len() of a str would count
            # characters rather than columns.
            num_input_columns = 1 if isinstance(input_columns, str) else len(input_columns)
            # The "- 1" accounts for the trailing BatchInfo parameter of per_batch_map.
            if num_input_columns != (len(ins.signature(per_batch_map).parameters) - 1):
                raise ValueError("The signature of per_batch_map should match with input columns.")

        if output_columns is not None:
            check_columns(output_columns, "output_columns")

        if python_multiprocessing is not None:
            type_check(python_multiprocessing, (bool,), "python_multiprocessing")

        return method(self, *args, **kwargs)

    return new_method
|
|
1336
|
+
|
|
1337
|
+
|
|
1338
|
+
def check_padded_batch(method):
    """
    Wrap padded_batch with validation of its arguments.

    batch_size must be a callable with exactly one parameter (BatchInfo) or an int that
    passes check_pos_int32. num_parallel_workers (when given) and drop_remainder are
    validated, and every pad_info entry is checked with check_pad_info.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        batch_size, drop_remainder, num_parallel_workers, pad_info = parsed

        if callable(batch_size):
            # A per-batch size function receives a single BatchInfo argument.
            if len(ins.signature(batch_size).parameters) != 1:
                raise ValueError("callable batch_size should take one parameter (BatchInfo).")
        elif isinstance(batch_size, int):
            check_pos_int32(int(batch_size), "batch_size")
        else:
            raise TypeError("batch_size should either be an int or a callable.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(drop_remainder, (bool,), "drop_remainder")

        if pad_info is not None:
            type_check(pad_info, (dict,), "pad_info")
            for column_name, column_pad in pad_info.items():
                check_pad_info(column_name, column_pad)

        return method(self, *args, **kwargs)

    return new_method
|
|
1367
|
+
|
|
1368
|
+
|
|
1369
|
+
def check_sync_wait(method):
    """
    Wrap sync_wait with argument validation.

    condition_name is checked against str and num_batch against int via type_check; the
    third parsed argument is not validated. check_independent_mode("Dataset sync wait")
    is then called before the wrapped method runs.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed_args, _ = parse_user_args(method, *args, **kwargs)
        condition_name, num_batch, _unused = parsed_args

        type_check(condition_name, (str,), "condition_name")
        type_check(num_batch, (int,), "num_batch")
        check_independent_mode("Dataset sync wait")

        return method(self, *args, **kwargs)

    return new_method
|
|
1384
|
+
|
|
1385
|
+
|
|
1386
|
+
def check_sync_update(method):
    """Wrap sync_update so check_independent_mode("Dataset sync update") runs before each call."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # No argument values are inspected here; only the mode check runs.
        check_independent_mode("Dataset sync update")
        result = method(self, *args, **kwargs)
        return result

    return new_method
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
def check_shuffle(method):
    """Wrap shuffle with validation of its single buffer_size argument."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (buffer_size,) = parsed

        type_check(buffer_size, (int,), "buffer_size")
        # Bounds passed to check_value: [2, INT32_MAX].
        check_value(buffer_size, [2, INT32_MAX], "buffer_size")

        return method(self, *args, **kwargs)

    return new_method
|
|
1411
|
+
|
|
1412
|
+
|
|
1413
|
+
def get_map_kwargs_from_dict(param_dict):
    """
    Extract the optional keyword arguments of the map operation.

    Args:
        param_dict (dict): Keyword arguments given to map. ``None`` is treated as an
            empty dict so every option falls back to its default.

    Returns:
        tuple: ``(python_multiprocessing, max_rowsize, cache, callbacks, offload)``;
        missing keys default to ``False, None, None, None, None``.
    """
    # Without this, a None param_dict would leave every local unbound.
    if param_dict is None:
        param_dict = {}
    python_multiprocessing = param_dict.get("python_multiprocessing", False)
    max_rowsize = param_dict.get("max_rowsize", None)
    cache = param_dict.get("cache", None)
    callbacks = param_dict.get("callbacks", None)
    offload = param_dict.get("offload", None)
    return python_multiprocessing, max_rowsize, cache, callbacks, offload
|
|
1422
|
+
|
|
1423
|
+
|
|
1424
|
+
def check_max_rowsize(max_rowsize):
    """
    Check the max_rowsize argument.

    Args:
        max_rowsize: None (skipped), a single int, or a list ``[in_rowsize, out_rowsize]``
            of two ints. Each int is checked with check_value against ``[-1, INT32_MAX]``.

    Raises:
        TypeError: If a list is given whose length is not 2.
        Whatever type_check / check_value raise for values of the wrong type or range.
    """
    if max_rowsize is not None:
        type_check(max_rowsize, (int, list), "max_rowsize")
        if isinstance(max_rowsize, int):
            # The type was already validated by the type_check above.
            check_value(max_rowsize, [-1, INT32_MAX], "max_rowsize")
        elif isinstance(max_rowsize, list) and len(max_rowsize) == 2:
            for index, value in enumerate(max_rowsize):
                item_name = "max_rowsize[{}]".format(index)
                type_check(value, (int,), item_name)
                check_value(value, [-1, INT32_MAX], item_name)
        else:
            raise TypeError("max_rowsize should be a single integer or a list[in_rowsize, out_rowsize] of length 2.")
|
|
1437
|
+
|
|
1438
|
+
|
|
1439
|
+
def check_map(method):
    """
    Wrap Dataset.map with validation of its arguments.

    Rejects the removed ``column_order`` parameter. For a plain Python function, sets
    ``self.operator_mixed = True`` when its closure variables or source text mention
    mindspore.nn / ops / numpy / compression. Rejects nn.Cell and ops.Primitive instances
    in ``operations``. Also validates num_parallel_workers, python_multiprocessing, cache,
    max_rowsize, offload, callbacks, input_columns and output_columns.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        from mindspore.dataset.callback import DSCallback
        [operations, input_columns, output_columns, column_order, num_parallel_workers, param_dict], _ = \
            parse_user_args(method, *args, **kwargs)

        if column_order is not None:
            raise ValueError("The parameter 'column_order' had been deleted in map operation. "
                             "Please use '.project' operation instead.\n"
                             ">> # Usage of old api:\n"
                             ">> dataset = dataset.map(operations=PyFunc,\n"
                             ">> input_columns=[\"column_a\"],\n"
                             ">> output_columns=[\"column_b\", \"column_c\"],\n"
                             ">> column_order=[\"column_b\", \"column_c\"])\n"
                             ">> # Usage of new api:\n"
                             ">> dataset = dataset.map(operations=PyFunc,\n"
                             ">> input_columns=[\"column_a\"],\n"
                             ">> output_columns=[\"column_b\", \"column_c\"])\n"
                             ">> dataset = dataset.project([\"column_b\", \"column_c\"])")

        (python_multiprocessing, max_rowsize, cache, callbacks, offload) = get_map_kwargs_from_dict(param_dict)

        # check whether network computing operator exist in input operations(python function)
        # check used variable and function document whether contain computing operator.
        # This is a text heuristic: it only sets the operator_mixed attribute.
        from types import FunctionType
        if isinstance(operations, FunctionType):
            try:
                var = ins.getclosurevars(operations)
                operations_doc = ins.getsource(operations)
                check_list = ['mindspore.nn', 'mindspore.ops', 'mindspore.numpy', 'mindspore.compression']
                check_doc = str(var) + operations_doc
                for item in check_list:
                    if item in check_doc:
                        setattr(self, 'operator_mixed', True)
                        break
            except OSError:
                # inspect.getsource raises OSError when the source cannot be retrieved;
                # the heuristic is skipped deliberately in that case.
                pass

        operations = operations if isinstance(operations, list) else [operations]
        # import nn and ops locally for type check
        from mindspore import nn, ops
        for item in operations:
            if isinstance(item, (nn.Cell, ops.Primitive)):
                # Build one message string; passing str(item) as a second positional
                # argument made the exception render as a tuple.
                raise ValueError("Input operations should not contain network computing operator like in "
                                 "mindspore.nn or mindspore.ops, got operation: " + str(item))

        nreq_param_columns = ['input_columns', 'output_columns']

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        type_check(python_multiprocessing, (bool,), "python_multiprocessing")
        check_cache_option(cache)
        check_max_rowsize(max_rowsize)
        if offload is not None:
            type_check(offload, (bool,), "offload")
            check_independent_mode("Dataset Offload", offload)

        if callbacks is not None:
            if isinstance(callbacks, (list, tuple)):
                type_check_list(callbacks, (DSCallback,), "callbacks")
            else:
                type_check(callbacks, (DSCallback,), "callbacks")
            check_independent_mode("Dataset Callbacks")

        for param_name, param in zip(nreq_param_columns, [input_columns, output_columns]):
            if param is not None:
                check_columns(param, param_name)
        if callbacks is not None:
            # NOTE(review): this accepts only list or DSCallback, so a tuple of callbacks that
            # passed type_check_list above is still rejected here. Confirm whether tuples are
            # meant to be supported.
            type_check(callbacks, (list, DSCallback), "callbacks")

        return method(self, *args, **kwargs)

    return new_method
|
|
1515
|
+
|
|
1516
|
+
|
|
1517
|
+
def check_filter(method):
    """
    Wrap filter with validation of its arguments.

    predicate must be callable; num_parallel_workers and input_columns are validated
    when they are not None.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        predicate, input_columns, num_parallel_workers = parsed

        if not callable(predicate):
            raise TypeError("Predicate should be a Python function or a callable Python object.")

        if num_parallel_workers is not None:
            check_num_parallel_workers(num_parallel_workers)
        if input_columns is not None:
            check_columns(input_columns, "input_columns")

        return method(self, *args, **kwargs)

    return new_method
|
|
1535
|
+
|
|
1536
|
+
|
|
1537
|
+
def check_repeat(method):
    """
    Wrap repeat with validation of its count argument.

    count is type-checked against int or None. An int must be -1 or lie in
    ``[1, INT32_MAX]``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (count,), _ = parse_user_args(method, *args, **kwargs)

        type_check(count, (int, type(None)), "repeat")
        if isinstance(count, int):
            repeat_forever = count == -1
            in_range = 0 < count <= INT32_MAX
            if not (repeat_forever or in_range):
                raise ValueError("count should be either -1 or positive integer, range[1, INT32_MAX].")

        return method(self, *args, **kwargs)

    return new_method
|
|
1551
|
+
|
|
1552
|
+
|
|
1553
|
+
def check_skip(method):
    """Wrap skip with validation of its count argument (int, bounds (0, INT32_MAX) via check_value)."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (count,) = parsed

        type_check(count, (int,), "count")
        check_value(count, (0, INT32_MAX), "count")

        return method(self, *args, **kwargs)

    return new_method
|
|
1566
|
+
|
|
1567
|
+
|
|
1568
|
+
def check_take(method):
    """Wrap take with validation of its count argument: an int that is -1 or in (0, INT32_MAX]."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (count,) = parsed

        type_check(count, (int,), "count")
        if count != -1 and not 0 < count <= INT32_MAX:
            raise ValueError("count should be either -1 or within the required interval of ({}, {}], got {}."
                             .format(0, INT32_MAX, count))

        return method(self, *args, **kwargs)

    return new_method
|
|
1582
|
+
|
|
1583
|
+
|
|
1584
|
+
def check_positive_int32(method):
    """
    Wrap a single-argument method so its argument, when not None, passes check_pos_int32.

    Only works for functions with one input besides self/cls. The argument name reported
    in errors is the last parameter name that is neither 'self' nor 'cls' (None if none).
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (count,), param_dict = parse_user_args(method, *args, **kwargs)
        user_param_names = [name for name in param_dict if name not in ('self', 'cls')]
        para_name = user_param_names[-1] if user_param_names else None

        # Need to get default value of param
        if count is not None:
            check_pos_int32(count, para_name)

        return method(self, *args, **kwargs)

    return new_method
|
|
1601
|
+
|
|
1602
|
+
|
|
1603
|
+
def check_device_send(method):
    """Wrap device_que: send_epoch_end and create_data_info_queue must be bool, queue_name a str."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        send_epoch_end, create_data_info_queue, queue_name = parsed

        expectations = (
            (send_epoch_end, bool, "send_epoch_end"),
            (create_data_info_queue, bool, "create_data_info_queue"),
            (queue_name, str, "queue_name"),
        )
        for value, expected_type, arg_name in expectations:
            type_check(value, (expected_type,), arg_name)

        return method(self, *args, **kwargs)

    return new_method
|
|
1616
|
+
|
|
1617
|
+
|
|
1618
|
+
def check_total_batch(total_batch):
    """Validate total_batch with check_int32, reporting it as "total_batch" in errors."""
    arg_name = "total_batch"
    check_int32(total_batch, arg_name)
|
|
1620
|
+
|
|
1621
|
+
|
|
1622
|
+
def check_zip(method):
    """Wrap the module-level zip so its single argument must be a tuple of datasets."""

    @wraps(method)
    def new_method(*args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (datasets_arg,) = parsed
        type_check(datasets_arg, (tuple,), "datasets")

        return method(*args, **kwargs)

    return new_method
|
|
1633
|
+
|
|
1634
|
+
|
|
1635
|
+
def check_zip_dataset(method):
    """Wrap `Dataset.zip`: the argument must be a tuple or a single `datasets.Dataset`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (other,) = parsed
        type_check(other, (tuple, datasets.Dataset), "datasets")

        return method(self, *args, **kwargs)

    return new_method
|
|
1646
|
+
|
|
1647
|
+
|
|
1648
|
+
def check_concat(method):
    """Wrap `Dataset.concat`: the argument must be a Dataset or a list whose items are all Datasets."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (other,) = parsed
        type_check(other, (list, datasets.Dataset), "datasets")
        if isinstance(other, list):
            # Every element of a list argument must itself be a Dataset.
            type_check_list(other, (datasets.Dataset,), "dataset")

        return method(self, *args, **kwargs)

    return new_method
|
|
1660
|
+
|
|
1661
|
+
|
|
1662
|
+
def check_rename(method):
    """
    Wrap rename with validation of its arguments.

    input_columns and output_columns must pass check_columns and name the same number of
    columns. A non-list value counts as one column.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        values, _ = parse_user_args(method, *args, **kwargs)

        for param_name, param in zip(['input_columns', 'output_columns'], values):
            check_columns(param, param_name)

        input_columns, output_columns = values
        input_size = len(input_columns) if isinstance(input_columns, list) else 1
        output_size = len(output_columns) if isinstance(output_columns, list) else 1
        if input_size != output_size:
            raise ValueError("Number of column in input_columns and output_columns is not equal.")

        return method(self, *args, **kwargs)

    return new_method
|
|
1685
|
+
|
|
1686
|
+
|
|
1687
|
+
def check_output_shape(method):
    """Wrap output_shape: the 'estimate' keyword is type-checked against bool."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)
        type_check(param_dict.get('estimate'), (bool,), "estimate")

        return method(self, *args, **kwargs)

    return new_method
|
|
1699
|
+
|
|
1700
|
+
|
|
1701
|
+
def check_project(method):
    """Wrap project: its single argument must pass check_columns as 'columns'."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (columns,) = parsed
        check_columns(columns, 'columns')

        return method(self, *args, **kwargs)

    return new_method
|
|
1712
|
+
|
|
1713
|
+
|
|
1714
|
+
def check_schema(method):
    """Wrap Schema.__init__: a non-None schema_file is validated with check_file."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        (schema_file,) = parsed

        if schema_file is None:
            return method(self, *args, **kwargs)

        check_file(schema_file)
        return method(self, *args, **kwargs)

    return new_method
|
|
1727
|
+
|
|
1728
|
+
|
|
1729
|
+
def check_add_column(method):
    """
    Wrap add_column with validation of its arguments.

    name must be a non-empty str. de_type is required and must be a ``typing.Type``
    instance or pass check_valid_detype. shape, when given, must be a list of ints.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        name, de_type, shape = parsed

        type_check(name, (str,), "name")
        if not name:
            raise TypeError("Expected non-empty string for column name.")

        if de_type is None:
            raise TypeError("Expected non-empty string for de_type.")
        if not (isinstance(de_type, typing.Type) or check_valid_detype(de_type)):
            raise TypeError("Unknown column type: {}.".format(de_type))

        if shape is not None:
            type_check(shape, (list,), "shape")
            type_check_list(shape, (int,), "shape")

        return method(self, *args, **kwargs)

    return new_method
|
|
1754
|
+
|
|
1755
|
+
|
|
1756
|
+
def check_cluedataset(method):
    """
    Parameter checker wrapped around the original Dataset(CLUEDataset) constructor.

    dataset_files must be a non-empty str or list. task must be one of the six CLUE task
    names and usage one of 'train', 'test' or 'eval'. The integer, sampler/shard and cache
    options are validated with the shared helpers.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        valid_tasks = ('AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC', 'CSL')
        if param_dict.get('task') not in valid_tasks:
            raise ValueError("task should be 'AFQMC', 'TNEWS', 'IFLYTEK', 'CMNLI', 'WSC' or 'CSL'.")

        if param_dict.get('usage') not in ('train', 'test', 'eval'):
            raise ValueError("usage should be 'train', 'test' or 'eval'.")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)
        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
1789
|
+
|
|
1790
|
+
|
|
1791
|
+
def check_csvdataset(method):
    """
    A wrapper that wraps a parameter checker around the original Dataset(CSVDataset).

    Validates:
      * dataset_files: required; a non-empty str or list.
      * field_delim: when given, a single-character str that is not '"', '\\r' or '\\n'.
      * column_defaults: when given, a list of str/int/float items.
      * column_names: when given, every item must be a str.
      * the integer, sampler/shard and cache options, via the shared helpers.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        # check dataset_files; required argument
        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (str, list), "dataset files")
        if not dataset_files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(dataset_files) + "'.")

        # check field_delim
        field_delim = param_dict.get('field_delim')
        if field_delim is not None:
            type_check(field_delim, (str,), 'field delim')
            # A delimiter must be exactly one character: "!= 1" also rejects the empty
            # string, which "> 1" let through.
            if field_delim in ['"', '\r', '\n'] or len(field_delim) != 1:
                raise ValueError("field_delim is invalid.")

        # check column_defaults
        column_defaults = param_dict.get('column_defaults')
        if column_defaults is not None:
            if not isinstance(column_defaults, list):
                raise TypeError("column_defaults should be type of list.")
            for item in column_defaults:
                if not isinstance(item, (str, int, float)):
                    raise TypeError("column type in column_defaults is invalid.")

        # check column_names: must be list of string.
        column_names = param_dict.get("column_names")
        if column_names is not None:
            all_string = all(isinstance(item, str) for item in column_names)
            if not all_string:
                raise TypeError("column_names should be a list of str.")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)

        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
1838
|
+
|
|
1839
|
+
|
|
1840
|
+
def check_flowers102dataset(method):
    """
    Parameter checker wrapped around the original Dataset(Flowers102Dataset) constructor.

    dataset_dir and its "jpg" sub-directory are checked with check_dir. "imagelabels.mat"
    and "setid.mat" are checked with check_file. For the Segmentation task, "segmim" is
    checked with check_dir as well. usage, task, and the int/bool/sampler options are
    validated too.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)
        check_dir(os.path.join(dataset_dir, "jpg"))
        for mat_name in ("imagelabels.mat", "setid.mat"):
            check_file(os.path.join(dataset_dir, mat_name))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        task = param_dict.get('task')
        if task is not None:
            check_valid_str(task, ["Classification", "Segmentation"], "task")
            if task == "Segmentation":
                check_dir(os.path.join(dataset_dir, "segmim"))

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     param_dict, int)
        validate_dataset_param_value(['shuffle', 'decode'], param_dict, bool)
        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
|
|
1876
|
+
|
|
1877
|
+
|
|
1878
|
+
def check_textfiledataset(method):
    """
    Parameter checker wrapped around the original Dataset(TextFileDataset) constructor.

    dataset_files must be a non-empty str or list. The integer, sampler/shard and cache
    options are validated with the shared helpers.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        files = param_dict.get('dataset_files')
        type_check(files, (str, list), "dataset files")
        if not files:
            raise ValueError("Input dataset_files can not be empty, but got '" + str(files) + "'.")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)
        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
1901
|
+
|
|
1902
|
+
|
|
1903
|
+
def check_penn_treebank_dataset(method):
    """
    Parameter checker wrapped around the original Dataset(PennTreebankDataset) constructor.

    dataset_dir (required) is checked with check_dir. usage, when given, must be one of
    "train", "valid", "test" or "all". The integer, sampler/shard and cache options are
    validated with the shared helpers.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        check_dir(param_dict.get('dataset_dir'))

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "valid", "test", "all"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, param_dict, int)
        check_sampler_shuffle_shard_options(param_dict)
        check_cache_option(param_dict.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
1930
|
+
|
|
1931
|
+
|
|
1932
|
+
def check_split(method):
    """
    Wrap split with validation of its arguments.

    sizes must be a non-empty list that is either all ints (each > 0) or all floats (each
    in (0, 1], summing to 1 within 1e-5). randomize must be a bool.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed, _ = parse_user_args(method, *args, **kwargs)
        sizes, randomize = parsed

        type_check(sizes, (list,), "sizes")
        type_check(randomize, (bool,), "randomize")

        if not sizes:
            raise ValueError("sizes cannot be empty.")

        sizes_are_int = all(isinstance(size, int) for size in sizes)
        sizes_are_float = all(isinstance(size, float) for size in sizes)
        if not sizes_are_int and not sizes_are_float:
            raise ValueError("sizes should be list of int or list of float.")

        if sizes_are_int and any(size <= 0 for size in sizes):
            raise ValueError("sizes is a list of int, but there should be no negative or zero numbers.")

        if sizes_are_float:
            if not all(0 < size <= 1 for size in sizes):
                raise ValueError("sizes is a list of float, but there should be no numbers outside the range (0, 1].")
            # Allow for floating-point rounding when checking that the fractions sum to 1.
            tolerance = 0.00001
            if not abs(sum(sizes) - 1) < tolerance:
                raise ValueError("sizes is a list of float, but the percentages do not sum up to 1.")

        return method(self, *args, **kwargs)

    return new_method
|
|
1969
|
+
|
|
1970
|
+
|
|
1971
|
+
def check_hostname(hostname):
    """Return True if ``hostname`` is a syntactically valid DNS hostname.

    The hostname must be non-empty and at most 255 characters. A single trailing
    dot is allowed. Every dot-separated label must be 1-63 ASCII letters, digits
    or hyphens, and must not start or end with a hyphen.
    """
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    # fullmatch (not match + "$") so a trailing newline is rejected; re.ASCII keeps
    # [A-Z] + IGNORECASE and \d from matching non-ASCII letters/digits.
    allowed = re.compile(r"(?!-)[A-Z\d-]{1,63}(?<!-)", re.IGNORECASE | re.ASCII)
    return all(allowed.fullmatch(label) for label in hostname.split("."))
|
|
1978
|
+
|
|
1979
|
+
|
|
1980
|
+
def check_numpyslicesdataset(method):
    """Decorate the NumpySlicesDataset constructor with argument validation.

    ``data`` must be a non-empty list, tuple, dict or numpy array. A tuple's
    first element must be a list or array. When ``column_names`` is given, the
    number of names must equal the number of dict keys or tuple elements, or be
    exactly one for any other data type.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        data = params.get("data")
        names = params.get("column_names")
        type_check(data, (list, tuple, dict, np.ndarray), "data")
        if data is None or len(data) == 0:  # pylint: disable=len-as-condition
            raise ValueError("Argument data cannot be empty")
        if isinstance(data, tuple):
            type_check(data[0], (list, np.ndarray), "data[0]")

        if names is None:
            return method(self, *args, **kwargs)

        check_columns(names, "column_names")

        # A bare string names exactly one column.
        given = 1 if isinstance(names, str) else len(names)
        if isinstance(data, (dict, tuple)):
            # One column per dict key / per tuple element.
            expected = len(data)
            if given != expected:
                raise ValueError("Num of input column names is {0}, but required is {1}."
                                 .format(given, expected))
        elif given != 1:
            raise ValueError("Num of input column names is {0}, but required is {1} as data is list."
                             .format(given, 1))

        return method(self, *args, **kwargs)

    return new_method
|
|
2019
|
+
|
|
2020
|
+
|
|
2021
|
+
def check_paddeddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(PaddedDataset).

    ``padded_samples`` must be a non-empty list whose elements are all dicts.

    Raises:
        ValueError: if ``padded_samples`` is empty/falsy.
        TypeError (via type_check): if it is not a list or any element is not a dict.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        padded_samples = param_dict.get("padded_samples")
        if not padded_samples:
            raise ValueError("padded_samples cannot be empty.")
        type_check(padded_samples, (list,), "padded_samples")
        # Check every sample, not only the first, so a later non-dict element is
        # reported here instead of failing later.
        for sample in padded_samples:
            type_check(sample, (dict,), "padded_element")
        return method(self, *args, **kwargs)

    return new_method
|
|
2036
|
+
|
|
2037
|
+
|
|
2038
|
+
def check_cache_option(cache):
    """Validate the optional ``cache`` argument shared by the dataset checkers.

    ``None`` is always accepted. Any other value must be a
    ``cache_client.DatasetCache`` and also goes through
    ``check_independent_mode("Dataset Cache")``.
    """
    if cache is None:
        return
    type_check(cache, (cache_client.DatasetCache,), "cache")
    check_independent_mode("Dataset Cache")
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
def check_to_device_send(method):
    """Validate the ``num_epochs`` argument of TransferDataset's send function.

    ``num_epochs`` may be None. Otherwise it must be an int and pass
    ``check_value`` against the range [-1, INT32_MAX].
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        (epochs,), _ = parse_user_args(method, *args, **kwargs)

        if epochs is not None:
            type_check(epochs, (int,), "num_epochs")
            check_value(epochs, [-1, INT32_MAX], "num_epochs")

        return method(self, *args, **kwargs)

    return new_method
|
|
2059
|
+
|
|
2060
|
+
|
|
2061
|
+
def check_emnist_dataset(method):
    """Decorate the EMNIST dataset constructor with argument validation.

    Runs, in order: int/bool type checks, dataset_dir, the mandatory ``name``,
    optional ``usage``, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(bool_params, params, bool)

        check_dir(params.get('dataset_dir'))

        # Unlike usage, name has no None guard, so it is always validated.
        valid_names = ["byclass", "bymerge", "balanced", "letters", "digits", "mnist"]
        check_valid_str(params.get('name'), valid_names, "name")

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2092
|
+
|
|
2093
|
+
|
|
2094
|
+
def check_flickr_dataset(method):
    """Decorate the Flickr8k/Flickr30k dataset constructors with argument validation.

    Checks dataset_dir and annotation_file, the int/bool options, the
    sampler/shuffle/shard options and ``cache``, in that order.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))
        check_file(params.get('annotation_file'))

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle', 'decode'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2120
|
+
|
|
2121
|
+
|
|
2122
|
+
def check_food101_dataset(method):
    """Decorate the Food101Dataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int/bool
    options, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['decode', 'shuffle']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(bool_params, params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2150
|
+
|
|
2151
|
+
|
|
2152
|
+
def check_sb_dataset(method):
    """Decorate the Semantic Boundaries dataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` and ``task``, the int/bool options
    and the sampler/shuffle/shard options. No cache check is performed here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "val", "train_noval", "all"], "usage")

        task = params.get('task')
        if task is not None:
            check_valid_str(task, ["Boundaries", "Segmentation"], "task")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(bool_params, params, bool)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return new_method
|
|
2181
|
+
|
|
2182
|
+
|
|
2183
|
+
def check_speech_commands_dataset(method):
    """Decorate the SpeechCommandsDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/valid/all), the int/bool
    options, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2211
|
+
|
|
2212
|
+
|
|
2213
|
+
def check_squad_dataset(method):
    """Decorate the SQuADDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/dev/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ['train', 'dev', 'all'], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2239
|
+
|
|
2240
|
+
|
|
2241
|
+
def check_cityscapes_dataset(method):
    """A wrapper that wraps a parameter checker around the original CityScapesDataset.

    Validates dataset_dir, ``task``, ``quality_mode``, ``usage`` (whose allowed
    values depend on quality_mode), the int/bool options, the
    sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        task = param_dict.get('task')
        check_valid_str(task, ["instance", "semantic", "polygon", "color"], "task")

        quality_mode = param_dict.get('quality_mode')
        check_valid_str(quality_mode, ["fine", "coarse"], "quality_mode")

        # The accepted usage values depend on quality_mode: "test" is only allowed
        # for "fine", "train_extra" only for "coarse".
        usage = param_dict.get('usage')
        if quality_mode == "fine":
            valid_strings = ["train", "test", "val", "all"]
        else:
            valid_strings = ["train", "train_extra", "val", "all"]
        check_valid_str(usage, valid_strings, "usage")

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        # Validate cache like the other dataset checkers do; a no-op when it is None/absent.
        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
2275
|
+
|
|
2276
|
+
|
|
2277
|
+
def check_div2k_dataset(method):
    """A wrapper that wraps a parameter checker around the original DIV2KDataset.

    Validates dataset_dir, ``usage``, ``downgrade``, ``scale`` (one of 2, 3, 4, 8
    and compatible with the downgrade), the int/bool options, the
    sampler/shuffle/shard options and ``cache``.

    Raises:
        ValueError: on an invalid scale or an incompatible scale/downgrade pair.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        check_valid_str(usage, ['train', 'valid', 'all'], "usage")

        downgrade = param_dict.get('downgrade')
        check_valid_str(downgrade, ['bicubic', 'unknown', 'mild', 'difficult', 'wild'], 'downgrade')

        validate_dataset_param_value(['scale'], param_dict, int)
        scale = param_dict.get('scale')
        scale_values = [2, 3, 4, 8]
        if scale not in scale_values:
            raise ValueError("Input scale is not within the valid set of {0}.".format(str(scale_values)))

        # scale 8 is only accepted together with the bicubic downgrade.
        if scale == 8 and downgrade != "bicubic":
            raise ValueError("DIV2KNode: scale equal to 8 is allowed only in bicubic downgrade.")

        # The mild/difficult/wild downgrades are only accepted with scale 4.
        downgrade_2018 = ["mild", "difficult", "wild"]
        if downgrade in downgrade_2018 and scale != 4:
            raise ValueError("DIV2KNode: {0} downgrade requires scale equal to 4.".format(downgrade))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        # Validate cache like the other dataset checkers do; a no-op when it is None/absent.
        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
2317
|
+
|
|
2318
|
+
|
|
2319
|
+
def check_fake_image_dataset(method):
    """Decorate the FakeImageDataset constructor with argument validation.

    After the int/bool type checks, requires positive int32 ``num_images``,
    a 3-element list/tuple ``image_size`` of positive int32 values, and a
    positive int32 ``num_classes``. Then runs the sampler/shuffle/shard and
    cache checks.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        int_params = ['num_images', 'num_classes', 'base_seed', 'num_samples',
                      'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_pos_int32(params.get("num_images"), "num_images")

        size = params.get("image_size")
        type_check(size, (list, tuple), "image_size")
        if len(size) != 3:
            raise ValueError("image_size should be a list or tuple of length 3, but got {0}".format(len(size)))
        for idx in range(len(size)):
            check_pos_int32(size[idx], "image_size[{0}]".format(idx))

        check_pos_int32(params.get("num_classes"), "num_classes")

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2354
|
+
|
|
2355
|
+
|
|
2356
|
+
def check_ag_news_dataset(method):
    """Decorate the AGNewsDataset constructor with argument validation.

    Checks the required dataset_dir, optional ``usage`` (train/test/all), the
    int options, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        # dataset_dir is a required argument
        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2383
|
+
|
|
2384
|
+
|
|
2385
|
+
def check_dbpedia_dataset(method):
    """Decorate the DBpediaDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2411
|
+
|
|
2412
|
+
|
|
2413
|
+
def check_wider_face_dataset(method):
    """Decorate the WIDERFaceDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/valid/all), the int/bool
    options, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['decode', 'shuffle']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(bool_params, params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2441
|
+
|
|
2442
|
+
|
|
2443
|
+
def check_yelp_review_dataset(method):
    """Decorate the YelpReviewDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2469
|
+
|
|
2470
|
+
|
|
2471
|
+
def check_yes_no_dataset(method):
    """Decorate the YesNoDataset constructor with argument validation.

    Checks dataset_dir, the int/bool options, the sampler/shuffle/shard options
    and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2495
|
+
|
|
2496
|
+
|
|
2497
|
+
def check_tedlium_dataset(method):
    """Decorate the TedliumDataset constructor with argument validation.

    Checks ``release``, dataset_dir, optional ``usage`` (release1/release2
    accept train/test/dev/all; other releases accept only all), the int/bool
    options, the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        release = params.get('release')
        check_valid_str(release, ["release1", "release2", "release3"], "release")

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            allowed = ["train", "test", "dev", "all"] if release in ("release1", "release2") else ["all"]
            check_valid_str(usage, allowed, "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2531
|
+
|
|
2532
|
+
|
|
2533
|
+
def check_svhn_dataset(method):
    """Decorate the SVHNDataset constructor with argument validation.

    Checks dataset_dir. When ``usage`` is given, validates it and requires the
    matching ``<part>_32x32.mat`` file(s) to exist ("all" means train, test and
    extra). Then runs the int/bool and sampler/shuffle/shard checks. No cache
    check is performed here.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)
        root = params.get('dataset_dir')
        check_dir(root)

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "extra", "all"], "usage")
            parts = ["train", "test", "extra"] if usage == "all" else [usage]
            for part in parts:
                check_file(os.path.join(root, part + "_32x32.mat"))

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)

        return method(self, *args, **kwargs)

    return new_method
|
|
2562
|
+
|
|
2563
|
+
|
|
2564
|
+
def check_sst2_dataset(method):
    """Decorate the SST2 dataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/dev), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "dev"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2590
|
+
|
|
2591
|
+
|
|
2592
|
+
def check_stl10_dataset(method):
    """Decorate the STL10Dataset constructor with argument validation.

    Checks dataset_dir. When ``usage`` is given, validates it and requires the
    matching ``*_X.bin`` / ``*_y.bin`` files to exist; the unlabeled part only
    has an ``_X.bin`` file. Then runs the int/bool, sampler/shuffle/shard and
    cache checks.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        root = params.get('dataset_dir')
        check_dir(root)

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "unlabeled", "train+unlabeled", "all"], "usage")
            # Files are checked in the same order as before; the first missing one raises.
            if usage == "all":
                required = ["train_X.bin", "train_y.bin", "test_X.bin", "test_y.bin", "unlabeled_X.bin"]
            elif usage == "train+unlabeled":
                required = ["train_X.bin", "train_y.bin", "unlabeled_X.bin"]
            elif usage == "unlabeled":
                required = ["unlabeled_X.bin"]
            else:
                required = [usage + "_X.bin", usage + "_y.bin"]
            for file_name in required:
                check_file(os.path.join(root, file_name))

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(['shuffle'], params, bool)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2636
|
+
|
|
2637
|
+
|
|
2638
|
+
def check_sun397_dataset(method):
    """Decorate the SUN397Dataset constructor with argument validation.

    Checks dataset_dir, the int/bool options, the sampler/shuffle/shard options
    and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_params = ['shuffle', 'decode']
        validate_dataset_param_value(int_params, params, int)
        validate_dataset_param_value(bool_params, params, bool)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2661
|
+
|
|
2662
|
+
|
|
2663
|
+
def check_yahoo_answers_dataset(method):
    """Decorate the YahooAnswers dataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)

        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2689
|
+
|
|
2690
|
+
|
|
2691
|
+
def check_conll2000_dataset(method):
    """Decorate the CoNLL2000Dataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        int_params = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_params, params, int)
        check_sampler_shuffle_shard_options(params)
        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2718
|
+
|
|
2719
|
+
|
|
2720
|
+
def check_amazon_review_dataset(method):
    """Decorate the AmazonReviewDataset constructor with argument validation.

    Checks dataset_dir, optional ``usage`` (train/test/all), the int options,
    the sampler/shuffle/shard options and ``cache``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        usage = params.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "all"], "usage")

        validate_dataset_param_value(['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'],
                                     params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return new_method
|
|
2747
|
+
|
|
2748
|
+
|
|
2749
|
+
def check_semeion_dataset(method):
    """Return ``method`` decorated with SemeionDataset argument checks.

    Checks performed on the bound arguments, in this order:
    ``check_dir(dataset_dir)``; integer validation of the sampling/sharding
    counters and boolean validation of ``shuffle`` via
    ``validate_dataset_param_value``; ``check_sampler_shuffle_shard_options``;
    and ``check_cache_option(cache)``.
    """

    @wraps(method)
    def checked_init(self, *args, **kwargs):
        _positional, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        bool_options = ['shuffle']
        validate_dataset_param_value(int_options, params, int)
        validate_dataset_param_value(bool_options, params, bool)

        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return checked_init
|
|
2773
|
+
|
|
2774
|
+
|
|
2775
|
+
def check_wiki_text_dataset(method):
    """Decorate WikiTextDataset's constructor with parameter validation.

    After binding the arguments with ``parse_user_args``, the wrapper checks
    ``dataset_dir`` with ``check_dir``. If ``usage`` is given, it must be one of
    "train", "valid", "test" or "all". The wrapper then type-checks the integer
    sampling/sharding options, validates the sampler/shuffle/shard combination,
    and checks ``cache`` before calling ``method``.
    """

    @wraps(method)
    def checked_init(self, *args, **kwargs):
        _positional, params = parse_user_args(method, *args, **kwargs)

        check_dir(params.get('dataset_dir'))

        split = params.get('usage')
        if split is not None:
            check_valid_str(split, ["train", "valid", "test", "all"], "usage")

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        validate_dataset_param_value(int_options, params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return checked_init
|
|
2802
|
+
|
|
2803
|
+
|
|
2804
|
+
def check_en_wik9_dataset(method):
    """Attach EnWik9Dataset argument validation to ``method``.

    The wrapper runs these checks on the bound arguments, in order:
    ``check_dir`` on ``dataset_dir``; ``validate_dataset_param_value`` with
    ``int`` on ``num_samples``, ``num_parallel_workers``, ``num_shards`` and
    ``shard_id``; ``check_sampler_shuffle_shard_options``; and
    ``check_cache_option`` on ``cache``. It then forwards the call.
    """

    @wraps(method)
    def checked_init(self, *args, **kwargs):
        _positional, params = parse_user_args(method, *args, **kwargs)

        int_options = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']

        check_dir(params.get('dataset_dir'))

        validate_dataset_param_value(int_options, params, int)
        check_sampler_shuffle_shard_options(params)

        check_cache_option(params.get('cache'))

        return method(self, *args, **kwargs)

    return checked_init
|
|
2824
|
+
|
|
2825
|
+
|
|
2826
|
+
def check_multi30k_dataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset (Multi30kDataset).

    Before delegating to ``method``, the wrapper binds the call's arguments with
    ``parse_user_args`` and validates them in order:

    - ``dataset_dir`` is passed to ``check_dir``.
    - ``usage``, if not None, must be one of "train", "test", "valid" or "all".
    - ``language_pair``, if not None, must be a list or tuple of length 2 equal
      to ``('en', 'de')`` or ``('de', 'en')``. A ValueError is raised otherwise.
    - The integer options and the boolean options (``shuffle``, ``decode``) are
      passed to ``validate_dataset_param_value``.
    - The sampler/shuffle/shard combination and ``cache`` are checked, as in the
      sibling dataset validators.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
        nreq_param_bool = ['shuffle', 'decode']

        dataset_dir = param_dict.get('dataset_dir')
        check_dir(dataset_dir)

        usage = param_dict.get('usage')
        if usage is not None:
            check_valid_str(usage, ["train", "test", "valid", "all"], "usage")

        language_pair = param_dict.get('language_pair')
        if language_pair is not None:
            type_check(language_pair, (list, tuple), "language_pair")
            if len(language_pair) != 2:
                raise ValueError(
                    "language_pair should be a list or tuple of length 2, but got {0}".format(len(language_pair)))
            # Normalize to a tuple so list and tuple inputs are compared the same way.
            if tuple(language_pair) not in (('en', 'de'), ('de', 'en')):
                raise ValueError(
                    "language_pair can only be ['en', 'de'] or ['de', 'en'], but got {0}".format(language_pair))

        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)

        check_sampler_shuffle_shard_options(param_dict)

        # Validate the cache option like every other dataset checker does;
        # a missing/None cache is accepted by check_cache_option.
        cache = param_dict.get('cache')
        check_cache_option(cache)

        return method(self, *args, **kwargs)

    return new_method
|
|
2862
|
+
|
|
2863
|
+
|
|
2864
|
+
def check_obsminddataset(method):
    """A wrapper that wraps a parameter checker around the original Dataset(OBSMindDataset).

    Before delegating to ``method``, the wrapper binds the call's arguments with
    ``parse_user_args`` and validates them:

    - ``dataset_files`` must be a list whose items are all ``str``. A TypeError
      is raised for any other item type.
    - ``num_shards``/``shard_id`` are validated as int, ``columns_list`` as list,
      ``shard_equal_rows`` as bool, and ``server``/``ak``/``sk``/``sync_obs_path``
      as str, via ``validate_dataset_param_value``.
    - ``server`` must be a str starting with "http://" or "https://". A
      ValueError is raised otherwise.
    - The sampler/shuffle/shard combination is passed to
      ``check_sampler_shuffle_shard_options``.
    """

    @wraps(method)
    def new_method(self, *args, **kwargs):
        _, param_dict = parse_user_args(method, *args, **kwargs)

        nreq_param_int = ['num_shards', 'shard_id']
        nreq_param_list = ['columns_list']
        nreq_param_bool = ['shard_equal_rows']
        nreq_param_str = ['server', 'ak', 'sk', 'sync_obs_path']

        dataset_files = param_dict.get('dataset_files')
        type_check(dataset_files, (list,), "dataset_files")
        for dataset_file in dataset_files:
            if not isinstance(dataset_file, str):
                raise TypeError("Item of dataset files is not of type [{}], but got {}.".format(type(''),
                                                                                                type(dataset_file)))
        validate_dataset_param_value(nreq_param_int, param_dict, int)
        validate_dataset_param_value(nreq_param_list, param_dict, list)
        validate_dataset_param_value(nreq_param_bool, param_dict, bool)
        validate_dataset_param_value(nreq_param_str, param_dict, str)

        server = param_dict.get('server')
        # NOTE(review): validate_dataset_param_value may skip None values; make sure
        # server really is a str before calling str methods on it, so an explicit
        # server=None yields a TypeError instead of an AttributeError.
        type_check(server, (str,), "server")
        if not server.startswith(('http://', 'https://')):
            raise ValueError("server should be a str that starts with http:// or https://, but got {}.".format(server))

        check_sampler_shuffle_shard_options(param_dict)

        return method(self, *args, **kwargs)

    return new_method
|