mindspore 2.1.0__cp37-cp37m-manylinux1_x86_64.whl → 2.2.11__cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +139 -22
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/ascend_profilier/cann_file_parser.py +76 -0
- mindspore/_akg/akg/utils/ascend_profilier/file_manager.py +56 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_bean.py +23 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_headers.py +8 -0
- mindspore/_akg/akg/utils/ascend_profilier/op_summary_parser.py +42 -0
- mindspore/_akg/akg/utils/ascend_profilier/path_manager.py +65 -0
- mindspore/_akg/akg/utils/composite_op_helper.py +16 -12
- mindspore/_akg/akg/utils/dump_ascend_meta.py +22 -3
- mindspore/_akg/akg/utils/kernel_exec.py +98 -274
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +219 -0
- mindspore/_akg/akg/utils/util.py +56 -1
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +47 -11
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/config/super_bar_config.json +4 -2
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/type_id.h +1 -0
- mindspore/include/mindapi/base/types.h +1 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310/aic-ascend310-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend310p/aic-ascend310p-ops-info.json +123 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910/aic-ascend910-ops-info.json +158 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/config/ascend910b/aic-ascend910b-ops-info.json +37 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/ai_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_dsl.py +46 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/add_tik.py +51 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/kv_cache_mgr.py +241 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_impl/vector_core/tbe/custom_aicore_ops_impl/matmul_tik.py +212 -0
- mindspore/lib/plugin/ascend/custom_aicore_ops/op_proto/libop_proto.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +8998 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libakg.so +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/scipy/linalg.py +1 -1
- mindspore/scipy/ops.py +55 -5
- mindspore/scipy/optimize/__init__.py +3 -2
- mindspore/scipy/optimize/linear_sum_assignment.py +38 -33
- mindspore/scipy/optimize/minimize.py +7 -3
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +8 -8
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +488 -539
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- /mindspore/{ops/_op_impl/_custom_op/flash_attention → _akg/akg/utils/ascend_profilier}/__init__.py +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
mindspore/include/api/net.h
DELETED
|
@@ -1,142 +0,0 @@
|
|
|
1
|
-
/**
|
|
2
|
-
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
|
3
|
-
*
|
|
4
|
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
-
* you may not use this file except in compliance with the License.
|
|
6
|
-
* You may obtain a copy of the License at
|
|
7
|
-
*
|
|
8
|
-
* http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
-
*
|
|
10
|
-
* Unless required by applicable law or agreed to in writing, software
|
|
11
|
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
-
* See the License for the specific language governing permissions and
|
|
14
|
-
* limitations under the License.
|
|
15
|
-
*/
|
|
16
|
-
|
|
17
|
-
#ifndef MINDSPORE_INCLUDE_API_NET_H
|
|
18
|
-
#define MINDSPORE_INCLUDE_API_NET_H
|
|
19
|
-
|
|
20
|
-
#include <memory>
|
|
21
|
-
#include <vector>
|
|
22
|
-
#include <unordered_set>
|
|
23
|
-
#include <string>
|
|
24
|
-
#include "include/api/types.h"
|
|
25
|
-
#include "include/api/data_type.h"
|
|
26
|
-
#include "include/api/cfg.h"
|
|
27
|
-
|
|
28
|
-
namespace mindspore {
|
|
29
|
-
/// \brief Register node or sub network
|
|
30
|
-
#define REG(_name) Register(_name, #_name)
|
|
31
|
-
|
|
32
|
-
class Expr;
|
|
33
|
-
class NodeImpl;
|
|
34
|
-
class NetImpl;
|
|
35
|
-
class NodeSet;
|
|
36
|
-
class Graph;
|
|
37
|
-
class NetData;
|
|
38
|
-
|
|
39
|
-
class MS_API NetBase {
|
|
40
|
-
public:
|
|
41
|
-
NetBase() = default;
|
|
42
|
-
virtual std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) = 0;
|
|
43
|
-
virtual uint32_t type() = 0;
|
|
44
|
-
};
|
|
45
|
-
|
|
46
|
-
class MS_API Node : public NetBase {
|
|
47
|
-
public:
|
|
48
|
-
Node();
|
|
49
|
-
virtual ~Node();
|
|
50
|
-
/// \brief Create output expression from node
|
|
51
|
-
|
|
52
|
-
/// \param[in] name Name of input (like "labels" etc.)
|
|
53
|
-
///
|
|
54
|
-
/// \return Expression
|
|
55
|
-
Expr *Create(std::string name);
|
|
56
|
-
/// \brief Run node on inputs. This operator is used in Net::construct()
|
|
57
|
-
///
|
|
58
|
-
/// \param[in] inputs Inputs expression for the node.
|
|
59
|
-
/// \return Output node expression vector
|
|
60
|
-
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs) override;
|
|
61
|
-
uint32_t type() final;
|
|
62
|
-
|
|
63
|
-
private:
|
|
64
|
-
friend NodeImpl;
|
|
65
|
-
std::shared_ptr<NodeImpl> impl_ = nullptr;
|
|
66
|
-
};
|
|
67
|
-
|
|
68
|
-
class MS_API Net : public NetBase, public std::enable_shared_from_this<Net> {
|
|
69
|
-
public:
|
|
70
|
-
Net();
|
|
71
|
-
virtual ~Net();
|
|
72
|
-
explicit Net(std::string name);
|
|
73
|
-
explicit Net(const Graph &g);
|
|
74
|
-
/// \brief Define the relation between network inputs and outputs
|
|
75
|
-
///
|
|
76
|
-
/// \param[in] inputs expression vector
|
|
77
|
-
///
|
|
78
|
-
/// \return expression vector
|
|
79
|
-
|
|
80
|
-
virtual std::vector<Expr *> construct(const std::vector<Expr *> &inputs);
|
|
81
|
-
/// \brief Addition operation
|
|
82
|
-
///
|
|
83
|
-
/// \param[in] inputs Two elements to add
|
|
84
|
-
///
|
|
85
|
-
/// \return expression vector (single element)
|
|
86
|
-
|
|
87
|
-
/// \brief Execution operator. Connect inputs to outputs via user defined construct
|
|
88
|
-
///
|
|
89
|
-
/// \return expression vector
|
|
90
|
-
|
|
91
|
-
std::vector<Expr *> operator()(const std::vector<Expr *> &inputs);
|
|
92
|
-
void Register(Net *net, std::string &&name);
|
|
93
|
-
void Register(Node *node, std::string &&name);
|
|
94
|
-
/// \brief Find the trainable params for the trained network
|
|
95
|
-
///
|
|
96
|
-
/// \return NodeSet for all trainable nodes
|
|
97
|
-
std::shared_ptr<NodeSet> trainable_params();
|
|
98
|
-
virtual void Add(NetBase *element);
|
|
99
|
-
/// \brief Input shape
|
|
100
|
-
///
|
|
101
|
-
/// \param[in] idx input index
|
|
102
|
-
///
|
|
103
|
-
/// \return Specific input shape vector
|
|
104
|
-
const std::vector<int> InputShape(int idx);
|
|
105
|
-
/// \brief Output shape
|
|
106
|
-
///
|
|
107
|
-
/// \param[in] idx Output index
|
|
108
|
-
///
|
|
109
|
-
/// \return Specific output shape vector
|
|
110
|
-
const std::vector<int> OutputShape(int idx);
|
|
111
|
-
uint32_t type() final;
|
|
112
|
-
|
|
113
|
-
private:
|
|
114
|
-
friend NetImpl;
|
|
115
|
-
friend NetData;
|
|
116
|
-
std::shared_ptr<NetImpl> impl_;
|
|
117
|
-
};
|
|
118
|
-
|
|
119
|
-
class MS_API SoftMaxCrossEntropyCfg {
|
|
120
|
-
public:
|
|
121
|
-
std::string reduction = "mean"; /**< Specifies reduction mode. The optional values are "none", "mean", "sum" */
|
|
122
|
-
};
|
|
123
|
-
|
|
124
|
-
class MS_API AdamConfig {
|
|
125
|
-
public:
|
|
126
|
-
float learning_rate_ = 1e-3;
|
|
127
|
-
float beta1_ = 0.9;
|
|
128
|
-
float beta2_ = 0.999;
|
|
129
|
-
float eps_ = 1e-08;
|
|
130
|
-
bool use_nesterov_ = false;
|
|
131
|
-
};
|
|
132
|
-
|
|
133
|
-
namespace NN {
|
|
134
|
-
MS_API Net *NetWithLoss(Net *net, Node *loss);
|
|
135
|
-
MS_API Graph *GraphWithLoss(Graph *g, Node *loss);
|
|
136
|
-
MS_API Node *Adam(std::shared_ptr<NodeSet> learn, const AdamConfig &cfg);
|
|
137
|
-
MS_API Node *SoftmaxCrossEntropy(const SoftMaxCrossEntropyCfg &cfg);
|
|
138
|
-
MS_API std::unique_ptr<Node> Input(std::vector<int> dims, DataType data_type = DataType::kNumberTypeFloat32,
|
|
139
|
-
int fmt = NHWC);
|
|
140
|
-
}; // namespace NN
|
|
141
|
-
} // namespace mindspore
|
|
142
|
-
#endif // MINDSPORE_INCLUDE_API_NET_H
|
mindspore/nn/lr_scheduler.py
DELETED
|
@@ -1,262 +0,0 @@
|
|
|
1
|
-
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""LRScheduler."""
|
|
16
|
-
from mindspore import ops
|
|
17
|
-
from mindspore.nn.optim_ex.optimizer import Optimizer
|
|
18
|
-
from mindspore.common.api import jit_class
|
|
19
|
-
from mindspore.common.parameter import Parameter
|
|
20
|
-
from mindspore.common import Tensor
|
|
21
|
-
import mindspore.common.dtype as mstype
|
|
22
|
-
from mindspore.ops import functional as F
|
|
23
|
-
from mindspore import _checkparam as Validator
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
__all__ = ['StepLR', 'LinearLR', 'LRScheduler']
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
@jit_class
|
|
30
|
-
class LRScheduler():
|
|
31
|
-
r"""
|
|
32
|
-
Basic class of learning rate schedule.
|
|
33
|
-
|
|
34
|
-
.. warning::
|
|
35
|
-
This is an experimental lr scheduler module that is subject to change.
|
|
36
|
-
This module must be used with optimizers in `Experimental Optimizer
|
|
37
|
-
<https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
|
|
38
|
-
|
|
39
|
-
Args:
|
|
40
|
-
optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): The optimizer instance.
|
|
41
|
-
last_epoch (int, optional): The epoch/step number. Default: ``-1``.
|
|
42
|
-
verbose (bool, optional): Whether to print lr information. Default: ``False``.
|
|
43
|
-
|
|
44
|
-
Raises:
|
|
45
|
-
TypeError: If `optimizer` is not an Optimizer.
|
|
46
|
-
TypeError: If `last_epoch` is not greater than -1.
|
|
47
|
-
ValueError: If `verbose` is not bool.
|
|
48
|
-
|
|
49
|
-
Supported Platforms:
|
|
50
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
51
|
-
"""
|
|
52
|
-
|
|
53
|
-
def __init__(self, optimizer, last_epoch=-1, verbose=False):
|
|
54
|
-
if not isinstance(optimizer, Optimizer):
|
|
55
|
-
raise TypeError('{} is not an Optimizer'.format(
|
|
56
|
-
type(optimizer).__name__))
|
|
57
|
-
Validator.check_value_type("last_epoch", last_epoch, [int])
|
|
58
|
-
if last_epoch < -1:
|
|
59
|
-
raise ValueError("Invalid last_epoch: {}".format(last_epoch))
|
|
60
|
-
Validator.check_value_type("verbose", verbose, [bool])
|
|
61
|
-
|
|
62
|
-
self.optimizer = optimizer
|
|
63
|
-
self._last_lr = []
|
|
64
|
-
self.groups_num = len(optimizer.param_groups)
|
|
65
|
-
self.verbose = verbose
|
|
66
|
-
self.last_epoch = Parameter(Tensor(last_epoch, dtype=mstype.float32),
|
|
67
|
-
name='last_epoch_' + self.__class__.__name__)
|
|
68
|
-
self.increase_tensor = Tensor(1, mstype.int32)
|
|
69
|
-
self.assignadd = ops.AssignAdd()
|
|
70
|
-
self.step()
|
|
71
|
-
|
|
72
|
-
@staticmethod
|
|
73
|
-
def _get_lr():
|
|
74
|
-
"""
|
|
75
|
-
Compute current lr.
|
|
76
|
-
|
|
77
|
-
This method must be overridden by all subclasses.
|
|
78
|
-
"""
|
|
79
|
-
raise NotImplementedError
|
|
80
|
-
|
|
81
|
-
@staticmethod
|
|
82
|
-
def _print_lr(is_verbose, group, lr):
|
|
83
|
-
"""
|
|
84
|
-
Display the current learning rate.
|
|
85
|
-
"""
|
|
86
|
-
if is_verbose:
|
|
87
|
-
print('Adjusting learning rate of group %s to %s.' % (group, lr.value()))
|
|
88
|
-
|
|
89
|
-
def get_last_lr(self):
|
|
90
|
-
"""
|
|
91
|
-
Return last computed learning rate by current scheduler.
|
|
92
|
-
"""
|
|
93
|
-
return [group["lr"].value() for group in self.optimizer.param_groups]
|
|
94
|
-
|
|
95
|
-
def step(self):
|
|
96
|
-
"""
|
|
97
|
-
Get the current learning rate and change the learning rate.
|
|
98
|
-
"""
|
|
99
|
-
self.assignadd(self.last_epoch, self.increase_tensor)
|
|
100
|
-
values = self._get_lr()
|
|
101
|
-
for i in range(self.groups_num):
|
|
102
|
-
lr = values[i]
|
|
103
|
-
lr = F.depend(lr, F.assign(self.optimizer.param_groups[i]["lr"], lr))
|
|
104
|
-
self._print_lr(self.verbose, i, lr)
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
@jit_class
|
|
108
|
-
class StepLR(LRScheduler):
|
|
109
|
-
"""Decays the learning rate of each parameter group by gamma every
|
|
110
|
-
step_size epochs. Notice that such decay can happen simultaneously with
|
|
111
|
-
other changes to the learning rate from outside this scheduler.
|
|
112
|
-
|
|
113
|
-
.. warning::
|
|
114
|
-
This is an experimental lr scheduler module that is subject to change.
|
|
115
|
-
This module must be used with optimizers in `Experimental Optimizer
|
|
116
|
-
<https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
|
|
117
|
-
|
|
118
|
-
Args:
|
|
119
|
-
optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): Wrapped optimizer.
|
|
120
|
-
step_size (int): Period of learning rate decay.
|
|
121
|
-
gamma (float, optional): Multiplicative factor of learning rate decay.
|
|
122
|
-
Default: ``0.1``.
|
|
123
|
-
last_epoch (int, optional): The index of last epoch. Default: ``-1``.
|
|
124
|
-
verbose (bool, optional): If ``True``, prints a message to stdout for
|
|
125
|
-
each update. Default: ``False``.
|
|
126
|
-
|
|
127
|
-
Supported Platforms:
|
|
128
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
129
|
-
|
|
130
|
-
Examples:
|
|
131
|
-
>>> import mindspore
|
|
132
|
-
>>> from mindspore import nn
|
|
133
|
-
>>> # Define the network structure of LeNet5. Refer to
|
|
134
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
135
|
-
>>> net = LeNet5()
|
|
136
|
-
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
137
|
-
>>> optimizer = nn.optim_ex.Adam(net.trainable_params(), lr=0.05)
|
|
138
|
-
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
|
139
|
-
>>> # lr = 0.05 if epoch < 2
|
|
140
|
-
>>> # lr = 0.005 if 2 <= epoch < 4
|
|
141
|
-
>>> # lr = 0.0005 if 4 <= epoch < 6
|
|
142
|
-
>>> scheduler = nn.StepLR(optimizer, step_size=2, gamma=0.1)
|
|
143
|
-
>>> def forward_fn(data, label):
|
|
144
|
-
... logits = net(data)
|
|
145
|
-
... loss = loss_fn(logits, label)
|
|
146
|
-
... return loss, logits
|
|
147
|
-
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
|
|
148
|
-
>>> def train_step(data, label):
|
|
149
|
-
... (loss, _), grads = grad_fn(data, label)
|
|
150
|
-
... optimizer(grads)
|
|
151
|
-
... return loss
|
|
152
|
-
>>> for epoch in range(6):
|
|
153
|
-
... # Create the dataset taking MNIST as an example. Refer to
|
|
154
|
-
... # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
|
|
155
|
-
... for data, label in create_dataset():
|
|
156
|
-
... train_step(data, label)
|
|
157
|
-
... scheduler.step()
|
|
158
|
-
... current_lr = scheduler.get_last_lr()
|
|
159
|
-
"""
|
|
160
|
-
def __init__(self, optimizer, step_size, gamma=0.5, last_epoch=-1, verbose=False):
|
|
161
|
-
self.step_size = step_size
|
|
162
|
-
self.gamma = gamma
|
|
163
|
-
super(StepLR, self).__init__(optimizer, last_epoch, verbose)
|
|
164
|
-
|
|
165
|
-
def _get_lr(self):
|
|
166
|
-
if (self.last_epoch == Tensor(0, mstype.float32)) or (
|
|
167
|
-
self.last_epoch % self.step_size != Tensor(0, mstype.float32)):
|
|
168
|
-
return [group['lr'] * 1. for group in self.optimizer.param_groups]
|
|
169
|
-
return [group['lr'] * self.gamma
|
|
170
|
-
for group in self.optimizer.param_groups]
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
@jit_class
|
|
174
|
-
class LinearLR(LRScheduler):
|
|
175
|
-
"""Decays the learning rate of each parameter group by linearly changing small
|
|
176
|
-
multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
|
|
177
|
-
Notice that such decay can happen simultaneously with other changes to the learning rate
|
|
178
|
-
from outside this scheduler.
|
|
179
|
-
|
|
180
|
-
.. warning::
|
|
181
|
-
This is an experimental lr scheduler module that is subject to change.
|
|
182
|
-
This module must be used with optimizers in `Experimental Optimizer
|
|
183
|
-
<https://www.mindspore.cn/docs/en/r2.1/api_python/mindspore.nn.html#experimental-optimizer>`_ .
|
|
184
|
-
|
|
185
|
-
Args:
|
|
186
|
-
optimizer (:class:`mindspore.nn.optim_ex.Optimizer`): Wrapped optimizer.
|
|
187
|
-
start_factor (float, optional): The number we multiply learning rate in the first epoch.
|
|
188
|
-
The multiplication factor changes towards `end_factor` in the following epochs.
|
|
189
|
-
Default: ``1.0 /3``.
|
|
190
|
-
end_factor (float, optional): The number we multiply learning rate at the end of linear changing
|
|
191
|
-
process. Default: ``1.0``.
|
|
192
|
-
total_iters (int, optional): The number of iterations that multiplicative factor reaches to 1.
|
|
193
|
-
Default: ``5``.
|
|
194
|
-
last_epoch (int, optional): The index of the last epoch. Default: ``-1``.
|
|
195
|
-
verbose (bool, optional): If ``True``, prints a message to stdout for
|
|
196
|
-
each update. Default: ``False``.
|
|
197
|
-
|
|
198
|
-
Raises:
|
|
199
|
-
ValueError: If `start_factor` is not in the range of (0, 1].
|
|
200
|
-
ValueError: If `end_factor` is not in the range of [0, 1].
|
|
201
|
-
|
|
202
|
-
Supported Platforms:
|
|
203
|
-
``Ascend`` ``GPU`` ``CPU``
|
|
204
|
-
|
|
205
|
-
Examples:
|
|
206
|
-
>>> import mindspore
|
|
207
|
-
>>> from mindspore.nn import LinearLR
|
|
208
|
-
>>> from mindspore import nn
|
|
209
|
-
>>> # Define the network structure of LeNet5. Refer to
|
|
210
|
-
>>> # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/lenet.py
|
|
211
|
-
>>> net = LeNet5()
|
|
212
|
-
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
213
|
-
>>> optimizer = nn.optim_ex.Adam(net.trainable_params(), lr=0.05)
|
|
214
|
-
>>> # Assuming optimizer uses lr = 0.05 for all groups
|
|
215
|
-
>>> # lr = 0.025 if epoch == 0
|
|
216
|
-
>>> # lr = 0.03125 if epoch == 1
|
|
217
|
-
>>> # lr = 0.0375 if epoch == 2
|
|
218
|
-
>>> # lr = 0.04375 if epoch == 3
|
|
219
|
-
>>> # lr = 0.05 if epoch >= 4
|
|
220
|
-
>>> scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)
|
|
221
|
-
>>> def forward_fn(data, label):
|
|
222
|
-
... logits = net(data)
|
|
223
|
-
... loss = loss_fn(logits, label)
|
|
224
|
-
... return loss, logits
|
|
225
|
-
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
|
|
226
|
-
>>> def train_step(data, label):
|
|
227
|
-
... (loss, _), grads = grad_fn(data, label)
|
|
228
|
-
... optimizer(grads)
|
|
229
|
-
... return loss
|
|
230
|
-
>>> for epoch in range(5):
|
|
231
|
-
... # Create the dataset taking MNIST as an example. Refer to
|
|
232
|
-
... # https://gitee.com/mindspore/docs/blob/r2.1/docs/mindspore/code/mnist.py
|
|
233
|
-
... for data, label in create_dataset():
|
|
234
|
-
... train_step(data, label)
|
|
235
|
-
... scheduler.step()
|
|
236
|
-
... current_lr = scheduler.get_last_lr()
|
|
237
|
-
"""
|
|
238
|
-
|
|
239
|
-
def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
|
|
240
|
-
verbose=False):
|
|
241
|
-
|
|
242
|
-
if start_factor > 1.0 or start_factor <= 0:
|
|
243
|
-
raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')
|
|
244
|
-
|
|
245
|
-
if end_factor > 1.0 or end_factor < 0:
|
|
246
|
-
raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
|
|
247
|
-
|
|
248
|
-
self.start_factor = start_factor
|
|
249
|
-
self.end_factor = end_factor
|
|
250
|
-
self.total_iters = total_iters
|
|
251
|
-
super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
|
|
252
|
-
|
|
253
|
-
def _get_lr(self):
|
|
254
|
-
if self.last_epoch == Tensor(0, mstype.float32):
|
|
255
|
-
return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
|
|
256
|
-
|
|
257
|
-
if self.last_epoch > self.total_iters:
|
|
258
|
-
return [group['lr'] * 1. for group in self.optimizer.param_groups]
|
|
259
|
-
|
|
260
|
-
return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
|
|
261
|
-
(self.total_iters * self.start_factor + (self.last_epoch - 1) *
|
|
262
|
-
(self.end_factor - self.start_factor))) for group in self.optimizer.param_groups]
|
|
@@ -1,248 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
|
|
16
|
-
"""image_ops"""
|
|
17
|
-
|
|
18
|
-
from mindspore import Tensor
|
|
19
|
-
from mindspore.common import dtype as mstype
|
|
20
|
-
from mindspore.ops._grad_experimental.grad_base import bprop_getters
|
|
21
|
-
from mindspore.ops import operations as P
|
|
22
|
-
from mindspore.ops import functional as F
|
|
23
|
-
from mindspore.ops.operations import _grad_ops as G
|
|
24
|
-
from mindspore.ops.operations.image_ops import ResizeBicubic
|
|
25
|
-
from mindspore.ops.operations._grad_ops import ResizeBicubicGrad
|
|
26
|
-
from mindspore.ops.operations.image_ops import ResizeV2
|
|
27
|
-
from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like
|
|
28
|
-
from mindspore.ops.operations.image_ops import CropAndResize
|
|
29
|
-
from mindspore.ops.operations.image_ops import CropAndResizeGradImage
|
|
30
|
-
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes
|
|
31
|
-
from mindspore.ops.operations.image_ops import RGBToHSV
|
|
32
|
-
from mindspore.ops.operations.image_ops import ScaleAndTranslate
|
|
33
|
-
from mindspore import context
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
@bprop_getters.register(ResizeBicubic)
|
|
37
|
-
def get_bprop_resize_bicubic(self):
|
|
38
|
-
"""Grad definition for `ResizeBicubic` operation."""
|
|
39
|
-
resize_bicubic_grad = ResizeBicubicGrad(align_corners=self.align_corners,
|
|
40
|
-
half_pixel_centers=self.half_pixel_centers)
|
|
41
|
-
|
|
42
|
-
def bprop(images, size, out, dout):
|
|
43
|
-
dx = resize_bicubic_grad(dout, images)
|
|
44
|
-
return (dx, P.ZerosLike()(size))
|
|
45
|
-
return bprop
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
@bprop_getters.register(ResizeV2)
|
|
49
|
-
def get_bprop_resize_v2(self):
|
|
50
|
-
"""Grad definition for `ResizeV2` operation."""
|
|
51
|
-
resize_v2_grad = G.ResizeV2Grad(coordinate_transformation_mode=self.coordinate_transformation_mode,
|
|
52
|
-
mode=self.mode)
|
|
53
|
-
|
|
54
|
-
def bprop(x, roi, scales, sizes, out, dout):
|
|
55
|
-
input_size = P.Shape()(x)
|
|
56
|
-
dx = resize_v2_grad(dout, roi, scales, Tensor(input_size))
|
|
57
|
-
return (dx, zeros_like(roi), zeros_like(scales), zeros_like(sizes))
|
|
58
|
-
return bprop
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
@bprop_getters.register(CropAndResize)
|
|
62
|
-
def get_bprop_crop_and_resize(self):
|
|
63
|
-
"""Grad definition for `CropAndResize` operation."""
|
|
64
|
-
allowed_types = [mstype.float16, mstype.float32, mstype.float64]
|
|
65
|
-
gradboxes = CropAndResizeGradBoxes(method="bilinear")
|
|
66
|
-
method_ = self.method
|
|
67
|
-
|
|
68
|
-
is_ascend_cpu = context.get_context('device_target') in ("Ascend", "CPU")
|
|
69
|
-
|
|
70
|
-
def bprop(x, boxes, box_index, crop_size, out, dout):
|
|
71
|
-
if method_ != "bilinear":
|
|
72
|
-
if not is_ascend_cpu:
|
|
73
|
-
return (zeros_like(x), zeros_like(boxes), zeros_like(box_index), zeros_like(crop_size))
|
|
74
|
-
image_type = x.dtype
|
|
75
|
-
if image_type not in allowed_types:
|
|
76
|
-
x = F.cast(x, mstype.float32)
|
|
77
|
-
dimage_type = image_type
|
|
78
|
-
gradimage = CropAndResizeGradImage(dimage_type, method=method_)
|
|
79
|
-
image_shape = x.shape
|
|
80
|
-
if F.is_sequence_value_unknown(image_shape):
|
|
81
|
-
image_size = P.TensorShape()(x)
|
|
82
|
-
image_size = F.cast(image_size, mstype.int32)
|
|
83
|
-
else:
|
|
84
|
-
image_size = Tensor(image_shape, dtype=mstype.int32)
|
|
85
|
-
dimage = gradimage(dout, boxes, box_index, image_size)
|
|
86
|
-
dbox = gradboxes(dout, x, boxes, box_index)
|
|
87
|
-
return (dimage, dbox, zeros_like(box_index), zeros_like(crop_size))
|
|
88
|
-
return bprop
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def crcp(x):
|
|
92
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
93
|
-
return P.DivNoNan()(1, x)
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def function1_rgbtohsv(images, out, dout):
|
|
97
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
98
|
-
dout = P.Cast()(dout, mstype.float32)
|
|
99
|
-
images = P.Cast()(images, mstype.float32)
|
|
100
|
-
out = P.Cast()(out, mstype.float32)
|
|
101
|
-
return images, out, dout
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
def function2_rgbtohsv(images):
|
|
105
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
106
|
-
# Input Channels
|
|
107
|
-
reds = images[..., 0]
|
|
108
|
-
greens = images[..., 1]
|
|
109
|
-
blues = images[..., 2]
|
|
110
|
-
return reds, greens, blues
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
def function3_rgbtohsv(out, reds):
|
|
114
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
115
|
-
# Output Channels
|
|
116
|
-
saturation = out[..., 1]
|
|
117
|
-
value = out[..., 2]
|
|
118
|
-
dsr1 = P.Cast()(reds > 0, mstype.float32)
|
|
119
|
-
return dsr1, saturation, value
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
def function4_rgbtohsv(reds, greens, blues):
|
|
123
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
124
|
-
r_b = P.LogicalAnd()((reds >= blues), (reds >= greens))
|
|
125
|
-
red_biggest = P.Cast()(r_b, mstype.float32)
|
|
126
|
-
g_b = P.LogicalAnd()((greens > reds), (greens >= blues))
|
|
127
|
-
green_biggest = P.Cast()(g_b, mstype.float32)
|
|
128
|
-
b_b = P.LogicalAnd()((blues > reds), (blues > greens))
|
|
129
|
-
blue_biggest = P.Cast()(b_b, mstype.float32)
|
|
130
|
-
return red_biggest, green_biggest, blue_biggest
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def function5_rgbtohsv(reds, greens, blues):
|
|
134
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
135
|
-
r_s = P.LogicalAnd()((reds < blues), (reds < greens))
|
|
136
|
-
red_smallest = P.Cast()(r_s, mstype.float32)
|
|
137
|
-
g_s = P.LogicalAnd()((greens <= reds), (greens < blues))
|
|
138
|
-
green_smallest = P.Cast()(g_s, mstype.float32)
|
|
139
|
-
b_s = P.LogicalAnd()((blues <= reds), (blues <= greens))
|
|
140
|
-
blue_smallest = P.Cast()(b_s, mstype.float32)
|
|
141
|
-
return red_smallest, green_smallest, blue_smallest
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
def function6_rgbtohsv(red_biggest, green_biggest, blue_biggest):
|
|
145
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
146
|
-
dv_dr = red_biggest
|
|
147
|
-
dv_dg = green_biggest
|
|
148
|
-
dv_db = blue_biggest
|
|
149
|
-
return dv_dr, dv_dg, dv_db
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
def function7_rgbtohsv(greens, green_biggest, dhb5, dh_db_1, dh_db_2, dh_db_3, dh_db_4,\
|
|
153
|
-
dout, dv_dr, dv_dg, dv_db, ds_dr, ds_dg, ds_db, dh_dr, dh_dg):
|
|
154
|
-
"""Grad definition for `RGBToHSV` operations."""
|
|
155
|
-
dh_db_5 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhb5)
|
|
156
|
-
|
|
157
|
-
dh_db = dh_db_1 + dh_db_2 + dh_db_3 + dh_db_4 + dh_db_5
|
|
158
|
-
|
|
159
|
-
dh_db = dh_db / 360
|
|
160
|
-
|
|
161
|
-
dv_drgb = P.Stack(-1)(
|
|
162
|
-
[dout[..., 2] * dv_dr, dout[..., 2] * dv_dg, dout[..., 2] * dv_db])
|
|
163
|
-
ds_drgb = P.Stack(-1)(
|
|
164
|
-
[dout[..., 1] * ds_dr, dout[..., 1] * ds_dg, dout[..., 1] * ds_db])
|
|
165
|
-
dh_drgb = P.Stack(-1)(
|
|
166
|
-
[dout[..., 0] * dh_dr, dout[..., 0] * dh_dg, dout[..., 0] * dh_db])
|
|
167
|
-
dvds_drgb = P.Add()(dv_drgb, ds_drgb)
|
|
168
|
-
doutient_input = P.Add()(dvds_drgb, dh_drgb)
|
|
169
|
-
return (doutient_input,)
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
@bprop_getters.register(RGBToHSV)
|
|
173
|
-
def get_bprop_rgb_to_hsv(self):
|
|
174
|
-
"""dout definition for 'RGBToHSV' operation"""
|
|
175
|
-
|
|
176
|
-
def bprop(images, out, dout):
|
|
177
|
-
images, out, dout = function1_rgbtohsv(images, out, dout)
|
|
178
|
-
reds, greens, blues = function2_rgbtohsv(images)
|
|
179
|
-
dsr1, saturation, value = function3_rgbtohsv(out, reds)
|
|
180
|
-
red_biggest, green_biggest, blue_biggest = function4_rgbtohsv(reds, greens, blues)
|
|
181
|
-
red_smallest, green_smallest, blue_smallest = function5_rgbtohsv(reds, greens, blues)
|
|
182
|
-
dv_dr, dv_dg, dv_db = function6_rgbtohsv(red_biggest, green_biggest, blue_biggest)
|
|
183
|
-
dsr2 = red_biggest * P.Add()(green_smallest * greens, blue_smallest * blues) * crcp(P.Square()(reds))
|
|
184
|
-
dsr3 = red_smallest * -1 * crcp((green_biggest * greens) + (blue_biggest * blues))
|
|
185
|
-
ds_dr = dsr1 * P.Add()(dsr2, dsr3)
|
|
186
|
-
dsg1 = P.Cast()((greens > 0), mstype.float32)
|
|
187
|
-
dsg2 = green_biggest * P.Add()(red_smallest * reds, blue_smallest * blues) * crcp(P.Square()(greens))
|
|
188
|
-
dsg3 = green_smallest * -1 * crcp((red_biggest * reds) + (blue_biggest * blues))
|
|
189
|
-
ds_dg = dsg1 * P.Add()(dsg2, dsg3)
|
|
190
|
-
|
|
191
|
-
dsb1 = P.Cast()((blues > 0), mstype.float32)
|
|
192
|
-
dsb2 = blue_biggest * P.Add()(green_smallest * greens, red_smallest * reds) * crcp(P.Square()(blues))
|
|
193
|
-
dsb3 = blue_smallest * -1 * crcp((green_biggest * greens) + (red_biggest * reds))
|
|
194
|
-
ds_db = dsb1 * P.Add()(dsb2, dsb3)
|
|
195
|
-
|
|
196
|
-
dhr1 = (greens - blues) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
|
|
197
|
-
dh_dr_1 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * -1 * dhr1)
|
|
198
|
-
dhr2 = red_smallest * (blues - greens) * crcp(P.Square()(reds - greens))
|
|
199
|
-
dh_dr_2 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhr2)
|
|
200
|
-
dhr3 = blue_smallest * -1 * crcp(greens - blues)
|
|
201
|
-
dh_dr_3 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhr3)
|
|
202
|
-
dhr4 = red_smallest * (blues - greens) * crcp(P.Square()(blues - reds))
|
|
203
|
-
dh_dr_4 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhr4)
|
|
204
|
-
dhr5 = green_smallest * crcp(blues - greens)
|
|
205
|
-
dh_dr_5 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhr5)
|
|
206
|
-
|
|
207
|
-
dh_dr = (dh_dr_1 + dh_dr_2 + dh_dr_3 + dh_dr_4 + dh_dr_5) / 360
|
|
208
|
-
|
|
209
|
-
dhg1 = (blues - reds) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
|
|
210
|
-
dh_dg_1 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * -1 * dhg1)
|
|
211
|
-
dhg2 = green_smallest * (reds - blues) * crcp(P.Square()(reds - greens))
|
|
212
|
-
dh_dg_2 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhg2)
|
|
213
|
-
dhg3 = blue_smallest * crcp(reds - blues)
|
|
214
|
-
dh_dg_3 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhg3)
|
|
215
|
-
dhg4 = green_smallest * (reds - blues) * crcp(P.Square()(blues - greens))
|
|
216
|
-
dh_dg_4 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhg4)
|
|
217
|
-
dhg5 = red_smallest * -1 * crcp(blues - reds)
|
|
218
|
-
dh_dg_5 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * dhg5)
|
|
219
|
-
|
|
220
|
-
dh_dg = (dh_dg_1 + dh_dg_2 + dh_dg_3 + dh_dg_4 + dh_dg_5) / 360
|
|
221
|
-
|
|
222
|
-
dhb1 = (reds - greens) * crcp(P.Square()(saturation)) * crcp(P.Square()(value))
|
|
223
|
-
dh_db_1 = 60 * (P.Cast()((blues > 0), mstype.float32) * blue_biggest * -1 * dhb1)
|
|
224
|
-
dhb2 = blue_smallest * (greens - reds) * crcp(P.Square()(reds - blues))
|
|
225
|
-
dh_db_2 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhb2)
|
|
226
|
-
dhb3 = green_smallest * -1 * crcp(reds - greens)
|
|
227
|
-
dh_db_3 = 60 * (P.Cast()((reds > 0), mstype.float32) * red_biggest * dhb3)
|
|
228
|
-
dhb4 = blue_smallest * (greens - reds) * crcp(P.Square()(greens - blues))
|
|
229
|
-
dh_db_4 = 60 * (P.Cast()((greens > 0), mstype.float32) * green_biggest * dhb4)
|
|
230
|
-
dhb5 = red_smallest * crcp(greens - reds)
|
|
231
|
-
return function7_rgbtohsv(greens, green_biggest, dhb5, dh_db_1, dh_db_2, dh_db_3,\
|
|
232
|
-
dh_db_4, dout, dv_dr, dv_dg, dv_db, ds_dr, ds_dg, ds_db, dh_dr, dh_dg)
|
|
233
|
-
return bprop
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
@bprop_getters.register(ScaleAndTranslate)
|
|
237
|
-
def get_bprop_scale_and_translate(self):
|
|
238
|
-
"""Grad definition for `ScaleAndTranslate` operation"""
|
|
239
|
-
scale_and_translate_grad = G.ScaleAndTranslateGrad(self.kernel_type, self.antialias)
|
|
240
|
-
|
|
241
|
-
def bprop(images, size, scale, translation, out, dout):
|
|
242
|
-
images_fp32 = F.cast(images, mstype.float32)
|
|
243
|
-
grad0_fp32 = scale_and_translate_grad(dout, images_fp32, scale, translation)
|
|
244
|
-
grad0 = F.cast(grad0_fp32, F.dtype(images))
|
|
245
|
-
result = (grad0, F.zeros_like(size), F.zeros_like(scale), F.zeros_like(translation))
|
|
246
|
-
return result
|
|
247
|
-
|
|
248
|
-
return bprop
|