mindspore 2.1.0__cp38-cp38-manylinux1_x86_64.whl → 2.2.11__cp38-cp38-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,506 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""GraphKernel Op Infer"""
|
|
16
|
-
|
|
17
|
-
import copy
|
|
18
|
-
import sys
|
|
19
|
-
from functools import reduce as prod_reduce
|
|
20
|
-
from .model import GraphKernelUnsupportedException as GKException
|
|
21
|
-
from .model import PrimLib, DataFormat as DF
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def infer(op_name, inputs, attrs):
|
|
25
|
-
"""infer shape dtype and format"""
|
|
26
|
-
|
|
27
|
-
def _create_opinfer():
|
|
28
|
-
self_module = sys.modules.get(__name__, None)
|
|
29
|
-
if self_module is None:
|
|
30
|
-
raise GKException("OpInfo does not support op {}".format(op_name))
|
|
31
|
-
|
|
32
|
-
if hasattr(self_module, op_name):
|
|
33
|
-
op_cls = getattr(self_module, op_name)
|
|
34
|
-
return op_cls(op_name, inputs, attrs)
|
|
35
|
-
# common infer
|
|
36
|
-
class_name_map = {
|
|
37
|
-
PrimLib.ELEMWISE: "_Elemwise",
|
|
38
|
-
PrimLib.REDUCE: "_Reduce",
|
|
39
|
-
}
|
|
40
|
-
cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
|
|
41
|
-
if not cls_name:
|
|
42
|
-
raise GKException("OpInfo does not support op {}".format(op_name))
|
|
43
|
-
op_cls = getattr(self_module, cls_name)
|
|
44
|
-
return op_cls(op_name, inputs, attrs)
|
|
45
|
-
|
|
46
|
-
return _create_opinfer().infer()
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
class OpInfer:
|
|
50
|
-
"""
|
|
51
|
-
OpInfer is the base class for inferring operator info in GraphKernel model builder.
|
|
52
|
-
|
|
53
|
-
There are three methods should be overridden to define the infer logic of the operator:
|
|
54
|
-
_infer_shape(), _infer_type() and _infer_format().
|
|
55
|
-
"""
|
|
56
|
-
|
|
57
|
-
def __init__(self, name, inputs, attrs):
|
|
58
|
-
self.name = name
|
|
59
|
-
self.inputs = inputs
|
|
60
|
-
self.attrs = attrs
|
|
61
|
-
|
|
62
|
-
def infer(self):
|
|
63
|
-
"""Infer shape, type and format by op inputs"""
|
|
64
|
-
self._check()
|
|
65
|
-
return self._infer_shape(), self._infer_type(), self._infer_format()
|
|
66
|
-
|
|
67
|
-
def _infer_shape(self):
|
|
68
|
-
return self.inputs[0].shape
|
|
69
|
-
|
|
70
|
-
def _infer_type(self):
|
|
71
|
-
return self.inputs[0].dtype
|
|
72
|
-
|
|
73
|
-
def _infer_format(self):
|
|
74
|
-
return self.inputs[0].data_format
|
|
75
|
-
|
|
76
|
-
def _check(self):
|
|
77
|
-
self._check_shape()
|
|
78
|
-
self._check_type()
|
|
79
|
-
self._check_format()
|
|
80
|
-
|
|
81
|
-
def _check_shape(self):
|
|
82
|
-
pass
|
|
83
|
-
|
|
84
|
-
def _check_type(self):
|
|
85
|
-
"""check all dtypes are same"""
|
|
86
|
-
dtype = self.inputs[0].dtype
|
|
87
|
-
for i, t in enumerate(self.inputs[1:]):
|
|
88
|
-
if t.dtype != dtype:
|
|
89
|
-
raise GKException(
|
|
90
|
-
"Incompatible data type between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
|
|
91
|
-
|
|
92
|
-
def _check_format(self):
|
|
93
|
-
"""check formats are compatible. only DefaultFormat is compatible with others"""
|
|
94
|
-
result = self.inputs[0].data_format
|
|
95
|
-
i = 0
|
|
96
|
-
for j, t in enumerate(self.inputs[1:]):
|
|
97
|
-
if t.data_format != result:
|
|
98
|
-
if DF.DEFAULT not in (result, t.data_format):
|
|
99
|
-
raise GKException("Incompatible format between input {}({}) and {}({})".format(
|
|
100
|
-
i, result, j + 1, t.data_format))
|
|
101
|
-
if result == DF.DEFAULT:
|
|
102
|
-
result = t.data_format
|
|
103
|
-
i = j + 1
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
class _Elemwise(OpInfer):
|
|
107
|
-
"""Common infer for elementwise operators"""
|
|
108
|
-
|
|
109
|
-
@staticmethod
|
|
110
|
-
def broadcast_shape(shapes):
|
|
111
|
-
"""deduce broadcast shape using same rules as numpy"""
|
|
112
|
-
dim_size = max(len(shape) for shape in shapes)
|
|
113
|
-
align_shapes = [[1] * (dim_size - len(shape)) + shape for shape in shapes]
|
|
114
|
-
out_shape = [1] * dim_size
|
|
115
|
-
for i in range(dim_size):
|
|
116
|
-
for align_shape in align_shapes:
|
|
117
|
-
if align_shape[i] == 1:
|
|
118
|
-
continue
|
|
119
|
-
if out_shape[i] == 1:
|
|
120
|
-
out_shape[i] = align_shape[i]
|
|
121
|
-
elif out_shape[i] != align_shape[i]:
|
|
122
|
-
raise GKException("Input shapes {} can not broadcast.".format(shapes))
|
|
123
|
-
return out_shape
|
|
124
|
-
|
|
125
|
-
@staticmethod
|
|
126
|
-
def defaultformat_to_nz(default_shape):
|
|
127
|
-
"""default format shape to fractal_Nz format shape"""
|
|
128
|
-
# As shape (1,) can broadcast to any shape, it can be regarded as a special FractalNZ shape
|
|
129
|
-
if len(default_shape) == 1 and default_shape[0] == 1:
|
|
130
|
-
return default_shape
|
|
131
|
-
more_two_d_shape, two_d_shape = default_shape[:-2], default_shape[-2:]
|
|
132
|
-
# (32) or (1, 32) -> (2, 1, 1, 16)
|
|
133
|
-
if len(two_d_shape) == 1 or (len(two_d_shape) == 2 and two_d_shape[0] == 1):
|
|
134
|
-
shape = [two_d_shape[-1] // 16, 1, 1, 16]
|
|
135
|
-
if two_d_shape[-1] % 16 != 0:
|
|
136
|
-
raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
|
|
137
|
-
"format shape[-1] should be multiplies of 16, but got {}"
|
|
138
|
-
.format(default_shape, two_d_shape[-1]))
|
|
139
|
-
return more_two_d_shape + shape
|
|
140
|
-
# (32, 1) -> (1, 2, 16, 1)
|
|
141
|
-
if len(two_d_shape) == 2 and two_d_shape[1] == 1:
|
|
142
|
-
shape = [1, two_d_shape[0] // 16, 16, 1]
|
|
143
|
-
if two_d_shape[0] % 16 != 0:
|
|
144
|
-
raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
|
|
145
|
-
"format shape[-2] should be multiples of 16, but got {}"
|
|
146
|
-
.format(default_shape, two_d_shape[0]))
|
|
147
|
-
return more_two_d_shape + shape
|
|
148
|
-
# (32, 48) -> (3, 2, 16, 16)
|
|
149
|
-
shape = [two_d_shape[1] // 16, two_d_shape[0] // 16, 16, 16]
|
|
150
|
-
if two_d_shape[0] % 16 != 0 or two_d_shape[1] % 16 != 0:
|
|
151
|
-
raise GKException("Can not convert default format shape{} to fractal_Nz format shape, because default "
|
|
152
|
-
"format shape[-2] and shape[-1] should be multiples of 16, but got {} and {}"
|
|
153
|
-
.format(default_shape, two_d_shape[0], two_d_shape[1]))
|
|
154
|
-
return more_two_d_shape + shape
|
|
155
|
-
|
|
156
|
-
def _infer_shape(self):
|
|
157
|
-
"""returns the output shape with broadcast"""
|
|
158
|
-
|
|
159
|
-
# in case all inputs are default format/NHWC/NCHW
|
|
160
|
-
is_default = [op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW) for op_input in self.inputs]
|
|
161
|
-
if all(is_default):
|
|
162
|
-
return self.broadcast_shape([op_input.shape for op_input in self.inputs])
|
|
163
|
-
|
|
164
|
-
# in case formats are fractal_nz, default_fromat/NHWC/HCHW(optional)
|
|
165
|
-
is_default_frac_nz = (op_input.data_format in (DF.DEFAULT, DF.NHWC, DF.NCHW, DF.FRAC_NZ)
|
|
166
|
-
for op_input in self.inputs)
|
|
167
|
-
if all(is_default_frac_nz):
|
|
168
|
-
nz_shapes = [self.defaultformat_to_nz(op_input.shape) if op_input.data_format != DF.FRAC_NZ
|
|
169
|
-
else op_input.shape for op_input in self.inputs]
|
|
170
|
-
return self.broadcast_shape(nz_shapes)
|
|
171
|
-
|
|
172
|
-
inputs_format = [op_input.data_format for op_input in self.inputs]
|
|
173
|
-
raise GKException("Only support DefaultFormat, NHWC, NCHW and FRACTAL_NZ in inputs format, but got {}"
|
|
174
|
-
.format(inputs_format))
|
|
175
|
-
|
|
176
|
-
def _infer_format(self):
|
|
177
|
-
for tensor in self.inputs:
|
|
178
|
-
if tensor.data_format != DF.DEFAULT:
|
|
179
|
-
return tensor.data_format
|
|
180
|
-
return DF.DEFAULT
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
class _Reduce(OpInfer):
|
|
184
|
-
"""Common infer for reduction operators"""
|
|
185
|
-
|
|
186
|
-
def _check(self):
|
|
187
|
-
super(_Reduce, self)._check()
|
|
188
|
-
# check reduce axis in the range [-len, len)
|
|
189
|
-
shape_len = len(self.inputs[0].shape)
|
|
190
|
-
axis = self.attrs['reduce_axis']
|
|
191
|
-
if isinstance(axis, int):
|
|
192
|
-
axis = [axis]
|
|
193
|
-
if not all((-shape_len <= i < shape_len) for i in axis):
|
|
194
|
-
raise GKException(
|
|
195
|
-
"Reduce axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
|
|
196
|
-
|
|
197
|
-
def _infer_shape(self):
|
|
198
|
-
shape = copy.deepcopy(self.inputs[0].shape)
|
|
199
|
-
axis = self.attrs['reduce_axis']
|
|
200
|
-
|
|
201
|
-
if isinstance(axis, int):
|
|
202
|
-
axis = [axis]
|
|
203
|
-
if any(i < 0 for i in axis):
|
|
204
|
-
# change the axis to non-negative number.
|
|
205
|
-
axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
|
|
206
|
-
self.attrs['reduce_axis'] = sorted(axis)
|
|
207
|
-
|
|
208
|
-
if self.attrs['keep_dims']:
|
|
209
|
-
for i in axis:
|
|
210
|
-
shape[i] = 1
|
|
211
|
-
return shape
|
|
212
|
-
|
|
213
|
-
real_shape = []
|
|
214
|
-
for i, s in enumerate(shape):
|
|
215
|
-
if i not in axis:
|
|
216
|
-
real_shape.append(s)
|
|
217
|
-
return real_shape
|
|
218
|
-
|
|
219
|
-
def _infer_format(self):
|
|
220
|
-
return DF.DEFAULT
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
class _Reshape(OpInfer):
|
|
224
|
-
"""Common infer for reshape operators, should not be instantiated"""
|
|
225
|
-
|
|
226
|
-
def _infer_shape(self):
|
|
227
|
-
raise GKException("_infer_shape should be implemented by subclass")
|
|
228
|
-
|
|
229
|
-
def _infer_format(self):
|
|
230
|
-
return DF.DEFAULT if "format" not in self.attrs else self.attrs["format"]
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
class Reshape(_Reshape):
|
|
234
|
-
"""Reshape op infer"""
|
|
235
|
-
|
|
236
|
-
def _check_shape(self):
|
|
237
|
-
input_shape = self.inputs[0].shape
|
|
238
|
-
output_shape = self.attrs["shape"]
|
|
239
|
-
size_before_reshape = prod_reduce(lambda x, y: x * y, input_shape)
|
|
240
|
-
size_after_reshape = prod_reduce(lambda x, y: x * y, output_shape)
|
|
241
|
-
if size_before_reshape != size_after_reshape:
|
|
242
|
-
raise GKException("For 'Reshape', can not reshape {} to {}".format(input_shape, output_shape))
|
|
243
|
-
|
|
244
|
-
def _infer_shape(self):
|
|
245
|
-
return self.attrs["shape"]
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
class Cast(_Elemwise):
|
|
249
|
-
"""Cast op infer"""
|
|
250
|
-
|
|
251
|
-
def _infer_type(self):
|
|
252
|
-
return self.attrs["dst_type"]
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
class InplaceAssign(_Elemwise):
|
|
256
|
-
"""InplaceAssign op infer"""
|
|
257
|
-
|
|
258
|
-
def _infer_shape(self):
|
|
259
|
-
return self.inputs[2].shape
|
|
260
|
-
|
|
261
|
-
def _infer_type(self):
|
|
262
|
-
return self.inputs[2].dtype
|
|
263
|
-
|
|
264
|
-
def _infer_format(self):
|
|
265
|
-
return self.inputs[2].data_format
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
class BroadcastTo(OpInfer):
|
|
269
|
-
"""BroadcastTo op infer"""
|
|
270
|
-
|
|
271
|
-
def _infer_shape(self):
|
|
272
|
-
return self.attrs["shape"]
|
|
273
|
-
|
|
274
|
-
def _infer_format(self):
|
|
275
|
-
return self.inputs[0].data_format
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
class _CompareOp(_Elemwise):
|
|
279
|
-
"""Compare operators"""
|
|
280
|
-
|
|
281
|
-
def _infer_type(self):
|
|
282
|
-
return "bool"
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
class CImag(OpInfer):
|
|
286
|
-
"""CImag op infer"""
|
|
287
|
-
|
|
288
|
-
def _check_type(self):
|
|
289
|
-
if self.inputs[0].dtype != "complex64" and self.inputs[0].dtype != "complex128":
|
|
290
|
-
raise GKException("For 'CImag', input[0] should be of type complex64 or"
|
|
291
|
-
"type complex128, but got {}".format(self.inputs[0].dtype))
|
|
292
|
-
|
|
293
|
-
def _infer_type(self):
|
|
294
|
-
if self.inputs[0].dtype == "complex64":
|
|
295
|
-
return "float32"
|
|
296
|
-
return "float64"
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
class CReal(OpInfer):
|
|
300
|
-
"""CReal op infer"""
|
|
301
|
-
|
|
302
|
-
def _check_type(self):
|
|
303
|
-
if self.inputs[0].dtype != "complex64" and self.inputs[0].dtype != "complex128":
|
|
304
|
-
raise GKException("For 'CReal', input[0] should be of type complex64 or"
|
|
305
|
-
"type complex128, but got {}".format(self.inputs[0].dtype))
|
|
306
|
-
|
|
307
|
-
def _infer_type(self):
|
|
308
|
-
if self.inputs[0].dtype == "complex64":
|
|
309
|
-
return "float32"
|
|
310
|
-
return "float64"
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
class Complex(OpInfer):
|
|
314
|
-
"""Complex op infer"""
|
|
315
|
-
|
|
316
|
-
def _check_type(self):
|
|
317
|
-
if self.inputs[0].dtype != "float32" and self.inputs[0].dtype != "float64":
|
|
318
|
-
raise GKException("For 'Complex', input[0] should be of type float32 or type float64,"
|
|
319
|
-
"but got {}".format(self.inputs[0].dtype))
|
|
320
|
-
if self.inputs[0].dtype != self.inputs[1].dtype:
|
|
321
|
-
raise GKException("For 'Complex', inputs data type mismatch ({} vs {})"
|
|
322
|
-
.format(self.inputs[0].dtype, self.inputs[1].dtype))
|
|
323
|
-
|
|
324
|
-
def _infer_type(self):
|
|
325
|
-
if self.inputs[0].dtype == "float32":
|
|
326
|
-
return "complex64"
|
|
327
|
-
return "complex128"
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
class Less(_CompareOp):
|
|
331
|
-
"""Less op infer"""
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
class LessEqual(_CompareOp):
|
|
335
|
-
"""LessEqual op infer"""
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
class Equal(_CompareOp):
|
|
339
|
-
"""Equal op infer"""
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
class Greater(_CompareOp):
|
|
343
|
-
"""Greater op infer"""
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
class GreaterEqual(_CompareOp):
|
|
347
|
-
"""GreaterEqual op infer"""
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
class Select(_Elemwise):
|
|
351
|
-
"""Select op infer"""
|
|
352
|
-
|
|
353
|
-
def _check_type(self):
|
|
354
|
-
if self.inputs[0].dtype != "bool":
|
|
355
|
-
raise GKException("For 'Select', input[0] should be of type bool, but got {}".format(self.inputs[0].dtype))
|
|
356
|
-
if self.inputs[1].dtype != self.inputs[2].dtype:
|
|
357
|
-
raise GKException("For 'Select', input[1] and input[2] data type mismatch ({} vs {})"
|
|
358
|
-
.format(self.inputs[1].dtype, self.inputs[2].dtype))
|
|
359
|
-
|
|
360
|
-
def _infer_type(self):
|
|
361
|
-
return self.inputs[1].dtype
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
def check_format_any(formats, checked_format):
|
|
365
|
-
"""Check whether input format in formats list"""
|
|
366
|
-
if not isinstance(formats, (list, tuple)):
|
|
367
|
-
raise GKException("formats {} should be of type list or tuple, but got {}.".format(formats, type(formats)))
|
|
368
|
-
if checked_format not in formats:
|
|
369
|
-
raise GKException("Check {} failed: can not find it in {}".format(checked_format, formats))
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
def check_nd(data, nd):
|
|
373
|
-
"""Check whether data are nd format"""
|
|
374
|
-
if not isinstance(data, (list, tuple)) or len(data) != nd:
|
|
375
|
-
raise GKException("input should be {}D list or tuple, but got {}.".format(nd, data))
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def conv_had_pad(pad_list, pad_mode):
|
|
379
|
-
"""Check whether conv need to add pad"""
|
|
380
|
-
if not isinstance(pad_list, (list, tuple)) or len(pad_list) != 4:
|
|
381
|
-
raise GKException("pad_list should be 4D list or tuple, but got {}".format(pad_list))
|
|
382
|
-
if pad_list[0] != pad_list[1] or pad_list[2] != pad_list[3]:
|
|
383
|
-
return True
|
|
384
|
-
if pad_mode not in ["VALID", "valid"]:
|
|
385
|
-
for _, pad in enumerate(pad_list):
|
|
386
|
-
if pad != 0:
|
|
387
|
-
return True
|
|
388
|
-
return False
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
class Conv2D(OpInfer):
|
|
392
|
-
"""Conv2D infer"""
|
|
393
|
-
|
|
394
|
-
def _infer_type(self):
|
|
395
|
-
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
|
|
396
|
-
return self.attrs["dst_type"]
|
|
397
|
-
return self.inputs[0].dtype
|
|
398
|
-
|
|
399
|
-
def _infer_shape(self):
|
|
400
|
-
shape_0 = list(self.inputs[0].shape)
|
|
401
|
-
shape_1 = list(self.inputs[1].shape)
|
|
402
|
-
check_nd(shape_0, 4)
|
|
403
|
-
check_nd(shape_1, 4)
|
|
404
|
-
|
|
405
|
-
formats = [self.inputs[0].data_format, self.inputs[1].data_format, self.attrs["format"]]
|
|
406
|
-
check_format_any(formats, DF.NHWC)
|
|
407
|
-
|
|
408
|
-
n, h, w, out_channel = shape_0[0], shape_0[1], shape_0[2], shape_1[0]
|
|
409
|
-
pad_list = self.attrs["pad_list"]
|
|
410
|
-
pad_mode = self.attrs["pad_mode"]
|
|
411
|
-
kernel_size = self.attrs["kernel_size"]
|
|
412
|
-
stride = self.attrs["stride"]
|
|
413
|
-
dilation = self.attrs["dilation"]
|
|
414
|
-
check_nd(pad_list, 4)
|
|
415
|
-
check_nd(kernel_size, 2)
|
|
416
|
-
check_nd(stride, 4)
|
|
417
|
-
check_nd(dilation, 4)
|
|
418
|
-
|
|
419
|
-
has_pad = conv_had_pad(pad_list, pad_mode)
|
|
420
|
-
if not has_pad:
|
|
421
|
-
pad_list = [0, 0, 0, 0]
|
|
422
|
-
|
|
423
|
-
k_h = (kernel_size[0] - 1) * dilation[-2] + 1
|
|
424
|
-
k_w = (kernel_size[1] - 1) * dilation[-1] + 1
|
|
425
|
-
out_h = (h + pad_list[0] + pad_list[1] - k_h) // stride[-2] + 1
|
|
426
|
-
out_w = (w + pad_list[2] + pad_list[3] - k_w) // stride[-1] + 1
|
|
427
|
-
return [n, out_h, out_w, out_channel]
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
class MatMul(OpInfer):
|
|
431
|
-
"""MatMul infer"""
|
|
432
|
-
|
|
433
|
-
def _infer_type(self):
|
|
434
|
-
if isinstance(self.attrs, dict) and "dst_type" in self.attrs:
|
|
435
|
-
return self.attrs["dst_type"]
|
|
436
|
-
return self.inputs[0].dtype
|
|
437
|
-
|
|
438
|
-
def _infer_shape(self):
|
|
439
|
-
shape_0 = list(self.inputs[0].shape)
|
|
440
|
-
shape_1 = list(self.inputs[1].shape)
|
|
441
|
-
if len(shape_0) != 2 or len(shape_1) != 2:
|
|
442
|
-
raise GKException("For 'MatMul', inputs shape must be 2D, but got {}, {}"
|
|
443
|
-
.format(shape_0, shape_1))
|
|
444
|
-
transpose_a = self.attrs["transpose_a"]
|
|
445
|
-
transpose_b = self.attrs["transpose_b"]
|
|
446
|
-
m, k1 = (shape_0[-1], shape_0[-2]) if transpose_a else (shape_0[-2], shape_0[-1])
|
|
447
|
-
k2, n = (shape_1[-1], shape_1[-2]) if transpose_b else (shape_1[-2], shape_1[-1])
|
|
448
|
-
if k1 != k2:
|
|
449
|
-
raise GKException("For 'MatMul', inputs have different k value: {} vs {}".format(k1, k2))
|
|
450
|
-
output_shape = [m, n]
|
|
451
|
-
return output_shape
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
class PadAkg(OpInfer):
|
|
455
|
-
"""PadAkg infer"""
|
|
456
|
-
|
|
457
|
-
def _infer_shape(self):
|
|
458
|
-
shape = list(self.inputs[0].shape)
|
|
459
|
-
n = len(shape)
|
|
460
|
-
pad_before = list(self.attrs["head"])
|
|
461
|
-
pad_after = list(self.attrs["tail"])
|
|
462
|
-
if len(pad_before) != n or len(pad_after) != n:
|
|
463
|
-
raise GKException("For 'PadAkg', input dimension and pad mismatch: {}d vs {}d vs {}d"
|
|
464
|
-
.format(n, len(pad_before), len(pad_after)))
|
|
465
|
-
out_shape = [shape[i] + pad_before[i] + pad_after[i] for i in range(n)]
|
|
466
|
-
return out_shape
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
class UnPadAkg(OpInfer):
|
|
470
|
-
"""UnPadAkg infer"""
|
|
471
|
-
|
|
472
|
-
def _infer_shape(self):
|
|
473
|
-
shape = list(self.inputs[0].shape)
|
|
474
|
-
n = len(shape)
|
|
475
|
-
unpad_after = list(self.attrs["tail"])
|
|
476
|
-
if len(unpad_after) != n:
|
|
477
|
-
raise GKException("For 'UnPadAkg', input dimension and pad mismatch: {}d vs {}d"
|
|
478
|
-
.format(n, len(unpad_after)))
|
|
479
|
-
out_shape = [shape[i] - unpad_after[i] for i in range(n)]
|
|
480
|
-
return out_shape
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
class Gather(OpInfer):
|
|
484
|
-
"""Gather infer"""
|
|
485
|
-
|
|
486
|
-
def _infer_shape(self):
|
|
487
|
-
input_shape = self.inputs[0].shape
|
|
488
|
-
indices_shape = self.inputs[1].shape
|
|
489
|
-
axis = self.attrs['axis']
|
|
490
|
-
output_shape = input_shape
|
|
491
|
-
indices_shape_one_dim = 1
|
|
492
|
-
for dim in indices_shape:
|
|
493
|
-
indices_shape_one_dim *= dim
|
|
494
|
-
output_shape[axis] = indices_shape_one_dim
|
|
495
|
-
return output_shape
|
|
496
|
-
|
|
497
|
-
def _infer_type(self):
|
|
498
|
-
return self.inputs[0].dtype
|
|
499
|
-
|
|
500
|
-
def _infer_format(self):
|
|
501
|
-
return self.inputs[0].data_format
|
|
502
|
-
|
|
503
|
-
def _check_type(self):
|
|
504
|
-
if self.inputs[1].dtype != "int32":
|
|
505
|
-
raise GKException("For 'Gather', inputs[1] should be of type int32, but got {}"
|
|
506
|
-
.format(self.inputs[1].dtype))
|
|
@@ -1,20 +0,0 @@
|
|
|
1
|
-
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
|
|
16
|
-
"""Preprocess of dataset.
|
|
17
|
-
"""
|
|
18
|
-
from __future__ import absolute_import
|
|
19
|
-
|
|
20
|
-
from mindspore.dataset.datapreprocess.preprocess_imagenet_validate_dataset import *
|
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
# Copyright 2019 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ==============================================================================
|
|
15
|
-
"""Process imagenet validate dataset.
|
|
16
|
-
"""
|
|
17
|
-
from __future__ import absolute_import
|
|
18
|
-
|
|
19
|
-
import os
|
|
20
|
-
import stat
|
|
21
|
-
from mindspore import log as logger
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def preprocess_imagenet_validation_dataset(train_dataset_path, validation_dataset_path, image_label_mapping_file):
|
|
25
|
-
"""
|
|
26
|
-
Call this function before read imagenet validation dataset.
|
|
27
|
-
|
|
28
|
-
Args:
|
|
29
|
-
train_dataset_path (str): train dataset path
|
|
30
|
-
validation_dataset_path (str): validation dataset path
|
|
31
|
-
image_label_mapping_file (str): imagenet_validate_dataset_2012_image_dir_map.txt file path
|
|
32
|
-
"""
|
|
33
|
-
train_dataset_path = os.path.realpath(train_dataset_path)
|
|
34
|
-
sub_dir = [dir_.name for dir_ in os.scandir(train_dataset_path) if dir_.is_dir()]
|
|
35
|
-
for sub_dir_name in sub_dir:
|
|
36
|
-
validate_sub_dir = os.path.join(validation_dataset_path, sub_dir_name)
|
|
37
|
-
validate_sub_dir = os.path.realpath(validate_sub_dir)
|
|
38
|
-
if not os.path.exists(validate_sub_dir):
|
|
39
|
-
os.makedirs(validate_sub_dir, mode=stat.S_IRWXU)
|
|
40
|
-
real_file_path = os.path.realpath(image_label_mapping_file)
|
|
41
|
-
mappings = [mapping.strip() for mapping in open(real_file_path).readlines()]
|
|
42
|
-
for mapping in mappings:
|
|
43
|
-
image_dir = mapping.split(':')
|
|
44
|
-
old_image_path = os.path.join(validation_dataset_path, image_dir[0])
|
|
45
|
-
old_image_path = os.path.realpath(old_image_path)
|
|
46
|
-
if not os.path.exists(old_image_path):
|
|
47
|
-
logger.warning('Image is not existed %s', old_image_path)
|
|
48
|
-
new_image_sub_dir = os.path.join(validation_dataset_path, image_dir[1])
|
|
49
|
-
new_image_sub_dir = os.path.realpath(new_image_sub_dir)
|
|
50
|
-
new_image_path = os.path.join(new_image_sub_dir, image_dir[0])
|
|
51
|
-
new_image_path = os.path.realpath(new_image_path)
|
|
52
|
-
if not os.path.exists(new_image_sub_dir):
|
|
53
|
-
logger.warning('Image sub dir is not existed %s', new_image_sub_dir)
|
|
54
|
-
os.rename(old_image_path, new_image_path)
|