mindspore 2.0.0rc1__cp38-none-any.whl → 2.2.0__cp38-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Third_Party_Open_Source_Software_Notice +2 -2
- mindspore/__init__.py +5 -2
- mindspore/_akg/akg/build_module.py +5 -6
- mindspore/_akg/akg/composite/build_module.py +49 -16
- mindspore/_akg/akg/composite/split_stitch.py +10 -11
- mindspore/_akg/akg/config/repository.json +195 -0
- mindspore/_akg/akg/global_configs.py +5 -1
- mindspore/_akg/akg/ms/info_version_adapt.py +67 -1
- mindspore/_akg/akg/tvm/api.py +4 -3
- mindspore/_akg/akg/tvm/autotvm/__init__.py +1 -2
- mindspore/_akg/akg/tvm/autotvm/graph_tuner/base_graph_tuner.py +1 -5
- mindspore/_akg/akg/tvm/autotvm/measure/__init__.py +1 -1
- mindspore/_akg/akg/tvm/autotvm/measure/measure.py +1 -10
- mindspore/_akg/akg/tvm/autotvm/measure/measure_methods.py +1 -372
- mindspore/_akg/akg/tvm/build_module.py +16 -1
- mindspore/_akg/akg/tvm/contrib/graph_runtime.py +0 -53
- mindspore/_akg/akg/tvm/hybrid/parser.py +7 -6
- mindspore/_akg/akg/tvm/ir_builder.py +1 -1
- mindspore/_akg/akg/tvm/module.py +1 -2
- mindspore/_akg/akg/tvm/stmt.py +2 -2
- mindspore/_akg/akg/utils/composite_op_helper.py +9 -10
- mindspore/_akg/akg/utils/kernel_exec.py +58 -260
- mindspore/_akg/akg/utils/op_dsl.py +17 -1
- mindspore/_akg/akg/utils/result_analysis.py +4 -24
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +198 -0
- mindspore/_c_dataengine.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_c_mindrecord.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/_check_jit_forbidden_api.py +5 -1
- mindspore/_checkparam.py +79 -62
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +2 -0
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +1 -9
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +128 -21
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +2 -2
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +18 -13
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +13 -9
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +19 -17
- mindspore/_extends/parse/namespace.py +7 -36
- mindspore/_extends/parse/parser.py +375 -189
- mindspore/_extends/parse/resources.py +36 -41
- mindspore/_extends/parse/standard_method.py +350 -245
- mindspore/_extends/parse/trope.py +2 -12
- mindspore/_extends/remote/kernel_build_server.py +24 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/_mindspore_offline_debug.cpython-38-aarch64-linux-gnu.so +0 -0
- mindspore/amp.py +85 -19
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/base.py +2 -2
- mindspore/boost/boost.py +27 -32
- mindspore/boost/boost_cell_wrapper.py +37 -13
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/grad_freeze.py +34 -6
- mindspore/boost/group_loss_scale_manager.py +15 -14
- mindspore/boost/less_batch_normalization.py +28 -3
- mindspore/common/__init__.py +15 -11
- mindspore/common/_auto_dynamic.py +68 -0
- mindspore/common/_jit_fallback_utils.py +111 -0
- mindspore/common/_register_for_adapter.py +17 -5
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +18 -15
- mindspore/common/_utils.py +31 -7
- mindspore/common/api.py +269 -101
- mindspore/common/auto_dynamic_shape.py +498 -0
- mindspore/common/dtype.py +61 -21
- mindspore/common/dump.py +9 -7
- mindspore/common/initializer.py +106 -76
- mindspore/common/jit_config.py +35 -14
- mindspore/common/lazy_inline.py +187 -0
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/mutable.py +10 -13
- mindspore/common/parameter.py +246 -55
- mindspore/common/seed.py +13 -7
- mindspore/common/sparse_tensor.py +29 -33
- mindspore/common/tensor.py +907 -251
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +84 -4
- mindspore/communication/management.py +160 -88
- mindspore/config/op_info.config +99 -75
- mindspore/config/super_bar_config.json +36 -4
- mindspore/context.py +526 -219
- mindspore/dataset/__init__.py +9 -46
- mindspore/dataset/audio/__init__.py +4 -19
- mindspore/dataset/audio/transforms.py +545 -233
- mindspore/dataset/audio/utils.py +21 -18
- mindspore/dataset/callback/ds_callback.py +42 -13
- mindspore/dataset/core/config.py +158 -100
- mindspore/dataset/core/validator_helpers.py +1 -63
- mindspore/dataset/debug/debug_hook.py +45 -13
- mindspore/dataset/debug/pre_defined_hook.py +5 -5
- mindspore/dataset/engine/__init__.py +0 -5
- mindspore/dataset/engine/cache_client.py +38 -15
- mindspore/dataset/engine/datasets.py +615 -278
- mindspore/dataset/engine/datasets_audio.py +154 -283
- mindspore/dataset/engine/datasets_standard_format.py +104 -116
- mindspore/dataset/engine/datasets_text.py +443 -326
- mindspore/dataset/engine/datasets_user_defined.py +251 -164
- mindspore/dataset/engine/datasets_vision.py +839 -1443
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +7 -3
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/offload.py +6 -6
- mindspore/dataset/engine/queue.py +15 -14
- mindspore/dataset/engine/samplers.py +39 -23
- mindspore/dataset/engine/serializer_deserializer.py +22 -6
- mindspore/dataset/engine/validators.py +21 -331
- mindspore/dataset/text/__init__.py +5 -33
- mindspore/dataset/text/transforms.py +334 -165
- mindspore/dataset/text/utils.py +215 -145
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/c_transforms.py +3 -2
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +174 -71
- mindspore/dataset/utils/browse_dataset.py +25 -17
- mindspore/dataset/utils/line_reader.py +24 -21
- mindspore/dataset/vision/__init__.py +5 -26
- mindspore/dataset/vision/c_transforms.py +177 -165
- mindspore/dataset/vision/py_transforms.py +114 -119
- mindspore/dataset/vision/py_transforms_util.py +54 -51
- mindspore/dataset/vision/transforms.py +1127 -381
- mindspore/dataset/vision/utils.py +54 -38
- mindspore/dataset/vision/validators.py +12 -2
- mindspore/experimental/map_parameter.py +38 -4
- mindspore/{dataset/datapreprocess → experimental/optim}/__init__.py +14 -4
- mindspore/experimental/optim/adam.py +192 -0
- mindspore/experimental/optim/adamw.py +181 -0
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/experimental/optim/optimizer.py +252 -0
- mindspore/experimental/optim/sgd.py +147 -0
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +1 -2
- mindspore/include/api/context.h +21 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +29 -42
- mindspore/include/api/model_group.h +14 -3
- mindspore/include/api/model_parallel_runner.h +18 -2
- mindspore/include/api/serialization.h +26 -0
- mindspore/include/api/status.h +1 -0
- mindspore/include/api/types.h +38 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/c_api/status_c.h +3 -0
- mindspore/include/dataset/constants.h +6 -12
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +25 -31
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/include/mindapi/base/format.h +0 -1
- mindspore/include/mindapi/base/type_id.h +2 -1
- mindspore/include/mindapi/base/types.h +5 -1
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libjemalloc.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libmpi_adapter.so +0 -0
- mindspore/lib/libnnacl.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/libps_cache.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +9000 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libakg.so +0 -0
- mindspore/lib/plugin/ascend/libascend_collective.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libhccl_plugin.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_aicpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/cpu/libakg.so +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.1 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/log.py +9 -6
- mindspore/mindrecord/filereader.py +33 -4
- mindspore/mindrecord/filewriter.py +70 -35
- mindspore/mindrecord/mindpage.py +40 -34
- mindspore/mindrecord/shardreader.py +1 -1
- mindspore/mindrecord/shardsegment.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +25 -18
- mindspore/mindrecord/tools/cifar10_to_mr.py +25 -18
- mindspore/mindrecord/tools/csv_to_mr.py +29 -13
- mindspore/mindrecord/tools/imagenet_to_mr.py +24 -10
- mindspore/mindrecord/tools/mnist_to_mr.py +24 -11
- mindspore/mindrecord/tools/tfrecord_to_mr.py +31 -26
- mindspore/nn/cell.py +463 -169
- mindspore/nn/dynamic_lr.py +47 -43
- mindspore/nn/layer/activation.py +225 -82
- mindspore/nn/layer/basic.py +121 -79
- mindspore/nn/layer/channel_shuffle.py +21 -21
- mindspore/nn/layer/combined.py +33 -26
- mindspore/nn/layer/container.py +277 -22
- mindspore/nn/layer/conv.py +441 -304
- mindspore/nn/layer/dense.py +19 -13
- mindspore/nn/layer/embedding.py +62 -49
- mindspore/nn/layer/flash_attention.py +264 -0
- mindspore/nn/layer/image.py +50 -39
- mindspore/nn/layer/math.py +62 -51
- mindspore/nn/layer/normalization.py +219 -167
- mindspore/nn/layer/padding.py +58 -70
- mindspore/nn/layer/pooling.py +334 -287
- mindspore/nn/layer/rnn_cells.py +53 -38
- mindspore/nn/layer/rnns.py +59 -56
- mindspore/nn/layer/thor_layer.py +52 -44
- mindspore/nn/layer/timedistributed.py +6 -4
- mindspore/nn/layer/transformer.py +284 -164
- mindspore/nn/learning_rate_schedule.py +34 -25
- mindspore/nn/loss/__init__.py +3 -2
- mindspore/nn/loss/loss.py +554 -311
- mindspore/nn/optim/ada_grad.py +12 -9
- mindspore/nn/optim/adadelta.py +14 -11
- mindspore/nn/optim/adafactor.py +19 -16
- mindspore/nn/optim/adam.py +62 -47
- mindspore/nn/optim/adamax.py +13 -10
- mindspore/nn/optim/adasum.py +12 -8
- mindspore/nn/optim/asgd.py +10 -9
- mindspore/nn/optim/ftrl.py +20 -17
- mindspore/nn/optim/lamb.py +16 -12
- mindspore/nn/optim/lars.py +8 -6
- mindspore/nn/optim/lazyadam.py +25 -20
- mindspore/nn/optim/momentum.py +10 -7
- mindspore/nn/optim/optimizer.py +61 -9
- mindspore/nn/optim/proximal_ada_grad.py +14 -13
- mindspore/nn/optim/rmsprop.py +17 -13
- mindspore/nn/optim/rprop.py +30 -17
- mindspore/nn/optim/sgd.py +40 -23
- mindspore/nn/optim/thor.py +24 -26
- mindspore/nn/probability/bijector/bijector.py +11 -11
- mindspore/nn/probability/bijector/exp.py +1 -1
- mindspore/nn/probability/bijector/gumbel_cdf.py +3 -3
- mindspore/nn/probability/bijector/invert.py +1 -1
- mindspore/nn/probability/bijector/power_transform.py +29 -29
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +5 -5
- mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +4 -2
- mindspore/nn/probability/bnn_layers/conv_variational.py +13 -13
- mindspore/nn/probability/bnn_layers/dense_variational.py +12 -12
- mindspore/nn/probability/bnn_layers/layer_distribution.py +9 -8
- mindspore/nn/probability/distribution/_utils/custom_ops.py +19 -3
- mindspore/nn/probability/distribution/_utils/utils.py +1 -1
- mindspore/nn/probability/distribution/bernoulli.py +9 -9
- mindspore/nn/probability/distribution/beta.py +8 -8
- mindspore/nn/probability/distribution/categorical.py +23 -15
- mindspore/nn/probability/distribution/cauchy.py +5 -6
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/exponential.py +4 -4
- mindspore/nn/probability/distribution/gamma.py +10 -10
- mindspore/nn/probability/distribution/geometric.py +8 -8
- mindspore/nn/probability/distribution/gumbel.py +8 -9
- mindspore/nn/probability/distribution/half_normal.py +5 -5
- mindspore/nn/probability/distribution/laplace.py +5 -5
- mindspore/nn/probability/distribution/log_normal.py +12 -11
- mindspore/nn/probability/distribution/logistic.py +8 -8
- mindspore/nn/probability/distribution/normal.py +6 -5
- mindspore/nn/probability/distribution/poisson.py +10 -11
- mindspore/nn/probability/distribution/student_t.py +8 -9
- mindspore/nn/probability/distribution/transformed_distribution.py +5 -5
- mindspore/nn/probability/distribution/uniform.py +11 -11
- mindspore/nn/reinforcement/tensor_array.py +2 -2
- mindspore/nn/sparse/sparse.py +9 -9
- mindspore/nn/wrap/cell_wrapper.py +188 -63
- mindspore/nn/wrap/grad_reducer.py +21 -12
- mindspore/nn/wrap/loss_scale.py +136 -49
- mindspore/numpy/__init__.py +4 -4
- mindspore/numpy/array_creations.py +55 -56
- mindspore/numpy/array_ops.py +134 -35
- mindspore/numpy/logic_ops.py +66 -20
- mindspore/numpy/math_ops.py +142 -139
- mindspore/numpy/utils_const.py +2 -2
- mindspore/offline_debug/convert_async.py +2 -2
- mindspore/ops/_grad_experimental/__init__.py +7 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +231 -348
- mindspore/ops/{_grad → _grad_experimental}/grad_base.py +1 -33
- mindspore/ops/{_grad → _grad_experimental}/grad_comm_ops.py +25 -13
- mindspore/ops/{_grad/__init__.py → _grad_experimental/grad_debug_ops.py} +15 -7
- mindspore/ops/{_grad → _grad_experimental}/grad_implementations.py +17 -11
- mindspore/ops/_grad_experimental/grad_inner_ops.py +33 -52
- mindspore/ops/_grad_experimental/grad_math_ops.py +151 -1224
- mindspore/ops/_grad_experimental/grad_nn_ops.py +141 -414
- mindspore/ops/{_grad → _grad_experimental}/grad_quant_ops.py +10 -6
- mindspore/ops/_grad_experimental/grad_sparse.py +317 -2
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -13
- mindspore/ops/{_grad → _grad_experimental}/taylor_rule.py +1 -1
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +406 -0
- mindspore/{_extends/graph_kernel/expanders/complex/__init__.py → ops/_op_impl/_custom_op/flash_attention/constants.py} +27 -8
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +467 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +563 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +193 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +435 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +45 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +67 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +62 -0
- mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +2 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +41 -1
- mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/cast.py +52 -0
- mindspore/ops/_op_impl/aicpu/coalesce.py +2 -0
- mindspore/ops/_op_impl/aicpu/col2im.py +3 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/dropout_genmask.py +6 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/eye.py +4 -4
- mindspore/ops/_op_impl/aicpu/fft_with_size.py +6 -0
- mindspore/ops/_op_impl/aicpu/fill_diagonal.py +5 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/im2col.py +3 -5
- mindspore/ops/_op_impl/aicpu/lgamma.py +1 -0
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu.py +39 -0
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/masked_scatter.py +1 -0
- mindspore/ops/_op_impl/aicpu/masked_select_grad.py +3 -0
- mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
- mindspore/ops/_op_impl/aicpu/matrix_power.py +6 -1
- mindspore/ops/_op_impl/aicpu/median.py +1 -0
- mindspore/ops/_op_impl/aicpu/multinomial.py +9 -9
- mindspore/ops/_op_impl/aicpu/not_equal.py +0 -5
- mindspore/ops/_op_impl/aicpu/pad_v3.py +3 -1
- mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +2 -0
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +0 -6
- mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +0 -7
- mindspore/ops/_op_impl/aicpu/scatter_nd.py +2 -0
- mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
- mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -4
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -4
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +14 -6
- mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +22 -8
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +11 -6
- mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +21 -10
- mindspore/ops/_op_impl/tbe/__init__.py +6 -4
- mindspore/ops/_op_impl/tbe/atomic_addr_clean.py +1 -1
- mindspore/ops/_op_impl/tbe/avg_pool.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_3d.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_3d_grad.py +4 -4
- mindspore/ops/_op_impl/tbe/avg_pool_ds.py +2 -2
- mindspore/ops/_op_impl/tbe/avg_pool_grad.py +3 -3
- mindspore/ops/_op_impl/tbe/avg_pool_grad_vm.py +3 -3
- mindspore/ops/_op_impl/tbe/batch_to_space.py +1 -1
- mindspore/ops/_op_impl/tbe/batch_to_space_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer.py +2 -2
- mindspore/ops/_op_impl/tbe/bn_infer_ds.py +3 -2
- mindspore/ops/_op_impl/tbe/broadcast_to.py +1 -1
- mindspore/ops/_op_impl/tbe/depthwise_conv2d.py +3 -3
- mindspore/ops/_op_impl/tbe/expand_dims.py +1 -1
- mindspore/ops/_op_impl/tbe/gather_v2.py +56 -0
- mindspore/ops/_op_impl/tbe/im2col.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/mem_set.py +38 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_add.py +3 -0
- mindspore/ops/_op_impl/tbe/scatter_nd_d.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch.py +1 -1
- mindspore/ops/_op_impl/tbe/space_to_batch_nd.py +2 -2
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +241 -0
- mindspore/ops/_utils/utils.py +10 -2
- mindspore/ops/_vmap/vmap_array_ops.py +5 -3
- mindspore/ops/_vmap/vmap_base.py +5 -4
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +11 -6
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +135 -11
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/__init__.py +7 -5
- mindspore/ops/composite/base.py +78 -34
- mindspore/ops/composite/math_ops.py +5 -695
- mindspore/ops/composite/multitype_ops/_compile_utils.py +403 -97
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +28 -22
- mindspore/ops/composite/multitype_ops/add_impl.py +69 -7
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/div_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/getitem_impl.py +48 -10
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/greater_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/less_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +2 -2
- mindspore/ops/composite/multitype_ops/mod_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/mul_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/negative_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +10 -7
- mindspore/ops/composite/multitype_ops/sub_impl.py +1 -0
- mindspore/ops/composite/multitype_ops/uadd_impl.py +2 -0
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +41 -4
- mindspore/ops/function/array_func.py +1108 -467
- mindspore/ops/function/clip_func.py +94 -27
- mindspore/ops/function/debug_func.py +3 -1
- mindspore/ops/function/grad/grad_func.py +82 -73
- mindspore/ops/function/image_func.py +28 -12
- mindspore/ops/function/linalg_func.py +135 -39
- mindspore/ops/function/math_func.py +3779 -894
- mindspore/ops/function/nn_func.py +1584 -657
- mindspore/ops/function/parameter_func.py +13 -3
- mindspore/ops/function/random_func.py +247 -153
- mindspore/ops/function/sparse_func.py +14 -11
- mindspore/ops/function/sparse_unary_func.py +173 -47
- mindspore/ops/function/spectral_func.py +8 -4
- mindspore/ops/function/vmap_func.py +8 -7
- mindspore/ops/functional.py +47 -16
- mindspore/ops/op_info_register.py +346 -86
- mindspore/ops/operations/__init__.py +38 -22
- mindspore/ops/operations/_grad_ops.py +145 -149
- mindspore/ops/operations/_inner_ops.py +298 -56
- mindspore/ops/operations/_ms_kernel.py +3 -3
- mindspore/ops/operations/_quant_ops.py +24 -28
- mindspore/ops/operations/_rl_inner_ops.py +9 -7
- mindspore/ops/operations/_scalar_ops.py +115 -0
- mindspore/ops/operations/_sequence_ops.py +148 -10
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/_thor_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +1239 -561
- mindspore/ops/operations/comm_ops.py +166 -90
- mindspore/ops/operations/control_ops.py +3 -3
- mindspore/ops/operations/custom_ops.py +124 -102
- mindspore/ops/operations/debug_ops.py +24 -11
- mindspore/ops/operations/image_ops.py +86 -71
- mindspore/ops/operations/inner_ops.py +18 -13
- mindspore/ops/operations/linalg_ops.py +30 -11
- mindspore/ops/operations/math_ops.py +1730 -435
- mindspore/ops/operations/nn_ops.py +1953 -943
- mindspore/ops/operations/other_ops.py +65 -43
- mindspore/ops/operations/random_ops.py +258 -98
- mindspore/ops/operations/rl_ops.py +4 -36
- mindspore/ops/operations/sparse_ops.py +38 -33
- mindspore/ops/operations/spectral_ops.py +8 -4
- mindspore/ops/primitive.py +66 -44
- mindspore/ops/signature.py +5 -5
- mindspore/parallel/_auto_parallel_context.py +80 -19
- mindspore/parallel/_cost_model_context.py +42 -0
- mindspore/parallel/_offload_context.py +162 -72
- mindspore/parallel/_parallel_serialization.py +2 -2
- mindspore/parallel/_ps_context.py +16 -4
- mindspore/parallel/_recovery_context.py +2 -1
- mindspore/parallel/_tensor.py +15 -13
- mindspore/parallel/_transformer/layers.py +8 -6
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +7 -7
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +34 -14
- mindspore/parallel/_utils.py +36 -14
- mindspore/parallel/algo_parameter_config.py +114 -20
- mindspore/parallel/checkpoint_transform.py +16 -18
- mindspore/parallel/shard.py +16 -13
- mindspore/profiler/__init__.py +1 -1
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +3 -2
- mindspore/profiler/envprofiling.py +11 -4
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +94 -0
- mindspore/profiler/parser/ascend_fpbp_generator.py +76 -0
- mindspore/profiler/parser/ascend_hccl_generator.py +288 -0
- mindspore/profiler/parser/ascend_msprof_exporter.py +213 -0
- mindspore/profiler/parser/ascend_msprof_generator.py +199 -0
- mindspore/profiler/parser/ascend_op_generator.py +276 -0
- mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
- mindspore/profiler/parser/ascend_timeline_generator.py +110 -54
- mindspore/profiler/parser/base_timeline_generator.py +11 -7
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +45 -46
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +92 -73
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +36 -11
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +4 -5
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +678 -377
- mindspore/rewrite/api/node.py +211 -54
- mindspore/rewrite/api/node_type.py +5 -0
- mindspore/rewrite/api/pattern_engine.py +22 -23
- mindspore/rewrite/api/scoped_value.py +20 -17
- mindspore/rewrite/api/symbol_tree.py +252 -106
- mindspore/rewrite/api/tree_node_helper.py +3 -0
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +97 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +51 -51
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{ops/bprop_mindir → rewrite/node}/__init__.py +9 -4
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +313 -247
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/node/node_topological_manager.py +243 -0
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +225 -239
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +179 -218
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +23 -20
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +202 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/sparsify.py +4 -1
- mindspore/rewrite/sparsify/utils.py +11 -5
- mindspore/rewrite/symbol_tree.py +577 -732
- mindspore/rewrite/symbol_tree_builder.py +9 -175
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +46 -39
- mindspore/run_check/run_check.py +3 -2
- mindspore/{scipy/sparse → safeguard}/__init__.py +4 -5
- mindspore/safeguard/rewrite_obfuscation.py +517 -0
- mindspore/scipy/__init__.py +1 -1
- mindspore/scipy/linalg.py +67 -61
- mindspore/scipy/ops.py +5 -41
- mindspore/scipy/ops_grad.py +3 -2
- mindspore/scipy/ops_wrapper.py +5 -5
- mindspore/scipy/optimize/line_search.py +8 -8
- mindspore/scipy/optimize/linear_sum_assignment.py +4 -4
- mindspore/scipy/optimize/minimize.py +16 -12
- mindspore/scipy/utils.py +1 -52
- mindspore/scipy/utils_const.py +4 -4
- mindspore/train/__init__.py +4 -4
- mindspore/train/_utils.py +13 -5
- mindspore/train/amp.py +410 -148
- mindspore/train/anf_ir_pb2.py +16 -4
- mindspore/train/callback/_backup_and_restore.py +8 -11
- mindspore/train/callback/_callback.py +80 -3
- mindspore/train/callback/_checkpoint.py +82 -51
- mindspore/train/callback/_early_stop.py +12 -15
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_lambda_callback.py +13 -13
- mindspore/train/callback/_landscape.py +21 -17
- mindspore/train/callback/_loss_monitor.py +9 -10
- mindspore/train/callback/_on_request_exit.py +16 -33
- mindspore/train/callback/_reduce_lr_on_plateau.py +21 -24
- mindspore/train/callback/_summary_collector.py +44 -30
- mindspore/train/callback/_time_monitor.py +62 -12
- mindspore/train/data_sink.py +10 -16
- mindspore/train/dataset_helper.py +154 -86
- mindspore/train/loss_scale_manager.py +14 -9
- mindspore/train/metrics/__init__.py +10 -2
- mindspore/train/metrics/accuracy.py +1 -1
- mindspore/train/metrics/auc.py +1 -1
- mindspore/train/metrics/bleu_score.py +2 -2
- mindspore/train/metrics/confusion_matrix.py +14 -14
- mindspore/train/metrics/cosine_similarity.py +3 -3
- mindspore/train/metrics/dice.py +1 -1
- mindspore/train/metrics/fbeta.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +8 -6
- mindspore/train/metrics/mean_surface_distance.py +5 -4
- mindspore/train/metrics/metric.py +49 -17
- mindspore/train/metrics/occlusion_sensitivity.py +4 -4
- mindspore/train/metrics/perplexity.py +1 -1
- mindspore/train/metrics/precision.py +2 -2
- mindspore/train/metrics/recall.py +2 -3
- mindspore/train/metrics/roc.py +7 -7
- mindspore/train/metrics/root_mean_square_surface_distance.py +5 -4
- mindspore/train/metrics/topk.py +7 -4
- mindspore/train/mind_ir_pb2.py +193 -48
- mindspore/train/model.py +377 -133
- mindspore/train/serialization.py +697 -245
- mindspore/train/summary/_summary_adapter.py +5 -2
- mindspore/train/summary/_writer_pool.py +4 -3
- mindspore/train/summary/summary_record.py +25 -23
- mindspore/train/train_thor/convert_utils.py +39 -23
- mindspore/train/train_thor/dataset_helper.py +4 -3
- mindspore/train/train_thor/model_thor.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/METADATA +7 -8
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/RECORD +633 -804
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/entry_points.txt +0 -1
- mindspore/_akg/akg/tvm/contrib/debugger/__init__.py +0 -16
- mindspore/_akg/akg/tvm/contrib/debugger/debug_result.py +0 -274
- mindspore/_akg/akg/tvm/contrib/debugger/debug_runtime.py +0 -259
- mindspore/_akg/akg/tvm/contrib/peak.py +0 -341
- mindspore/_akg/akg/tvm/contrib/rpc.py +0 -25
- mindspore/_akg/akg/tvm/contrib/xcode.py +0 -257
- mindspore/_akg/akg/tvm/exec/__init__.py +0 -17
- mindspore/_akg/akg/tvm/exec/autotvm_log_editor.py +0 -60
- mindspore/_akg/akg/tvm/exec/measure_peak.py +0 -48
- mindspore/_akg/akg/tvm/exec/query_rpc_tracker.py +0 -48
- mindspore/_akg/akg/tvm/exec/rpc_proxy.py +0 -98
- mindspore/_akg/akg/tvm/exec/rpc_server.py +0 -88
- mindspore/_akg/akg/tvm/exec/rpc_tracker.py +0 -62
- mindspore/_akg/akg/tvm/rpc/__init__.py +0 -29
- mindspore/_akg/akg/tvm/rpc/base.py +0 -182
- mindspore/_akg/akg/tvm/rpc/client.py +0 -436
- mindspore/_akg/akg/tvm/rpc/proxy.py +0 -595
- mindspore/_akg/akg/tvm/rpc/server.py +0 -413
- mindspore/_akg/akg/tvm/rpc/tornado_util.py +0 -121
- mindspore/_akg/akg/tvm/rpc/tracker.py +0 -431
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -57
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/bias_add_grad.py +0 -49
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gather.py +0 -43
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/tile.py +0 -54
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/_extends/parse/jit_fallback_modules.py +0 -51
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/dataset/engine/graphdata.py +0 -1586
- mindspore/include/api/net.h +0 -142
- mindspore/ops/_grad/grad_array_ops.py +0 -1347
- mindspore/ops/_grad/grad_clip_ops.py +0 -84
- mindspore/ops/_grad/grad_debug_ops.py +0 -68
- mindspore/ops/_grad/grad_inner_ops.py +0 -235
- mindspore/ops/_grad/grad_math_ops.py +0 -1684
- mindspore/ops/_grad/grad_nn_ops.py +0 -1529
- mindspore/ops/_grad/grad_other_ops.py +0 -89
- mindspore/ops/_grad/grad_sequence_ops.py +0 -296
- mindspore/ops/_grad/grad_sparse.py +0 -323
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -249
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -195
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/bprop_mindir/AdaptiveAvgPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/AdaptiveMaxPool2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ApproximateEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Argmax_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Argmin_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/AssignSub_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Assign_bprop.mindir +0 -17
- mindspore/ops/bprop_mindir/AvgPool3D_bprop.mindir +0 -150
- mindspore/ops/bprop_mindir/AvgPool_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/BCEWithLogitsLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/BatchNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BatchToSpaceND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/BiasAddGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/BinaryCrossEntropy_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/BroadcastTo_bprop.mindir +0 -306
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/CTCLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Concat_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Conv2DBackpropFilter_bprop.mindir +0 -240
- mindspore/ops/bprop_mindir/Conv2DBackpropInput_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv2DTranspose_bprop.mindir +0 -247
- mindspore/ops/bprop_mindir/Conv3DTranspose_bprop.mindir +0 -315
- mindspore/ops/bprop_mindir/Conv3D_bprop.mindir +0 -278
- mindspore/ops/bprop_mindir/DType_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/DeformableOffsets_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/DepthToSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/DiagPart_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Dropout2D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Dropout3D_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DropoutDoMask_bprop.mindir +0 -25
- mindspore/ops/bprop_mindir/DropoutGenMask_bprop.mindir +0 -18
- mindspore/ops/bprop_mindir/DropoutGrad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Dropout_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicGRUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicRNN_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DynamicShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Elu_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Equal_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ExpandDims_bprop.mindir +0 -58
- mindspore/ops/bprop_mindir/FastGeLU_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Flatten_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/FloorDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/GatherD_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/GatherNd_bprop.mindir +0 -57
- mindspore/ops/bprop_mindir/Gather_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/GreaterEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Greater_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/HSigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/HSwish_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/IOU_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/InstanceNorm_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/IsFinite_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsInf_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/IsNan_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/KLDivLoss_bprop.mindir +0 -126
- mindspore/ops/bprop_mindir/L2Loss_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/L2Normalize_bprop.mindir +0 -30
- mindspore/ops/bprop_mindir/LRN_bprop.mindir +0 -43
- mindspore/ops/bprop_mindir/LayerNormGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/LessEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/Less_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LinSpace_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -13
- mindspore/ops/bprop_mindir/LogSoftmax_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/LogicalAnd_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/LogicalNot_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/MaskedSelect_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/MaxPool3DGradGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3DGrad_bprop.mindir +0 -74
- mindspore/ops/bprop_mindir/MaxPool3D_bprop.mindir +0 -75
- mindspore/ops/bprop_mindir/MaxPoolGradGrad_bprop.mindir +0 -65
- mindspore/ops/bprop_mindir/MaxPoolWithArgmax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Maximum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Minimum_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/MirrorPad_bprop.mindir +0 -27
- mindspore/ops/bprop_mindir/Mish_bprop.mindir +0 -35
- mindspore/ops/bprop_mindir/MulNoNan_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NLLLoss_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/NonZero_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/NotEqual_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/OneHot_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/OnesLike_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/PReLU_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Pad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Padding_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/RNNTLoss_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ROIAlign_bprop.mindir +0 -82
- mindspore/ops/bprop_mindir/Range_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Rank_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/ReLU6_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/ReLUV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ReduceAll_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReduceAny_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/ReluGrad_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Reshape_bprop.mindir +0 -60
- mindspore/ops/bprop_mindir/ResizeBilinear_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/ResizeNearestNeighbor_bprop.mindir +0 -89
- mindspore/ops/bprop_mindir/ReverseSequence_bprop.mindir +0 -52
- mindspore/ops/bprop_mindir/ReverseV2_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Round_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/ScatterMax_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterMin_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNdUpdate_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterNd_bprop.mindir +0 -24
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/ScatterUpdate_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SeLU_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/Select_bprop.mindir +0 -31
- mindspore/ops/bprop_mindir/Shape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/SigmoidCrossEntropyWithLogits_bprop.mindir +0 -21
- mindspore/ops/bprop_mindir/SigmoidGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Sigmoid_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Sign_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/Slice_bprop.mindir +0 -26
- mindspore/ops/bprop_mindir/SmoothL1Loss_bprop.mindir +0 -36
- mindspore/ops/bprop_mindir/SoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Softplus_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Softsign_bprop.mindir +0 -33
- mindspore/ops/bprop_mindir/Sort_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SpaceToBatchND_bprop.mindir +0 -28
- mindspore/ops/bprop_mindir/SpaceToDepth_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Split_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/Squeeze_bprop.mindir +0 -54
- mindspore/ops/bprop_mindir/StridedSliceGrad_bprop.mindir +0 -95
- mindspore/ops/bprop_mindir/StridedSlice_bprop.mindir +0 -98
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TanhGrad_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Tanh_bprop.mindir +0 -66
- mindspore/ops/bprop_mindir/TensorScatterAdd_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/TensorScatterUpdate_bprop.mindir +0 -29
- mindspore/ops/bprop_mindir/TensorShape_bprop.mindir +0 -14
- mindspore/ops/bprop_mindir/Tile_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TopK_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -23
- mindspore/ops/bprop_mindir/TruncateDiv_bprop.mindir +0 -19
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -20
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -16
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -22
- mindspore/ops/bprop_mindir/UpsampleNearest3D_bprop.mindir +0 -32
- mindspore/ops/bprop_mindir/UpsampleTrilinear3D_bprop.mindir +0 -38
- mindspore/ops/bprop_mindir/ZerosLike_bprop.mindir +0 -15
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- mindspore/rewrite/topological_manager.py +0 -203
- mindspore/scipy/sparse/linalg.py +0 -192
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/WHEEL +0 -0
- {mindspore-2.0.0rc1.dist-info → mindspore-2.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""optimizer"""
|
|
16
|
+
from __future__ import absolute_import
from collections import defaultdict
from typing import Iterable
from mindspore.ops import functional as F, composite as C, operations as P

from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import Tensor
from mindspore.common.sparse_tensor import RowTensorInner
import mindspore.common.dtype as mstype
from mindspore import _checkparam as validator
from mindspore import log as logger
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Public API of this module: only the Optimizer base class is exported.
__all__ = ['Optimizer']
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Optimizer(Cell):
    r"""
    Base class for all optimizers.

    .. warning::
        This is an experimental optimizer API that is subject to change.
        This module must be used with lr scheduler module in `LRScheduler Class
        <https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .

    Args:
        params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
            dict. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).

    Raises:
        ValueError: If `params` is None or empty.
        TypeError: If `params` is not iterable.
        TypeError: If a parameter group is not a dict, or an element of a group's ``"params"`` is not a
            Parameter.
        TypeError: If a group's ``"params"`` is a set (its ordering is not stable between runs).
        ValueError: If a Parameter appears in more than one parameter group.
        TypeError: If `learning_rate` is not one of int, float, Tensor.
        ValueError: If `learning_rate` is a negative number.
        ValueError: If `learning_rate` is a Tensor that is not a scalar (0-D) Tensor.
        TypeError: If `weight_decay` is neither float nor int.
        ValueError: If `weight_decay` is less than 0.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import nn, Tensor, Parameter
        >>> from mindspore import ops
        >>> from mindspore.experimental import optim
        >>>
        >>> class MySGD(optim.Optimizer):
        ...     def __init__(self, params, lr):
        ...         defaults = dict(lr=lr)
        ...         super(MySGD, self).__init__(params, defaults)
        ...
        ...     def construct(self, gradients):
        ...         for group_id, group in enumerate(self.param_groups):
        ...             start = self.group_start_id[group_id]
        ...             for i, param in enumerate(group["params"]):
        ...                 next_param = param - gradients[start + i] * group["lr"]
        ...                 ops.assign(param, next_param)
        >>>
        >>> net = nn.Dense(8, 2)
        >>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
        >>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
        >>>
        >>> optimizer = MySGD(net.trainable_params(), 0.01)
        >>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
        >>>
        >>> criterion = nn.MAELoss(reduction="mean")
        >>>
        >>> def forward_fn(data, label):
        ...     logits = net(data)
        ...     loss = criterion(logits, label)
        ...     return loss, logits
        >>>
        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        >>>
        >>> def train_step(data, label):
        ...     (loss, _), grads = grad_fn(data, label)
        ...     optimizer(grads)
        ...     print(loss)
        >>>
        >>> train_step(data, label)
    """
    def __init__(self, params, defaults):
        super(Optimizer, self).__init__(auto_prefix=False)

        param_groups = self._parameters_base_check(params, "params")
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []
        # Flat list of every optimized Parameter across all groups, in group order.
        self.parameters = []
        self.map_ = C.Map()
        # group_start_id[k] is the offset of group k's first parameter in `self.parameters`;
        # the final entry is the total parameter count.
        self.group_start_id = [0]
        # A plain iterable of Parameters is treated as a single parameter group.
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
        self.parameters = ParameterTuple(self.parameters)
        self.hyper_map = C.HyperMap()
        self.enable_tuple_broaden = True

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key == 'params':
                    continue
                value = group[key]
                # The group learning rate is stored as a Parameter; print its current value.
                if key == "lr" and isinstance(value, Parameter):
                    value = value.value()
                format_string += ' {0}: {1}\n'.format(key, value)
        format_string += ')'
        return format_string

    def add_param_group(self, param_group):
        r"""
        Add a param group to the `Optimizer.param_groups`.

        The given dict is validated and then updated in place: missing options are filled in from
        `defaults`, ``"lr"`` is replaced by a float32 Parameter and ``"weight_decay"`` by a float.

        Args:
            param_group (dict): Specifies what Parameters should be optimized along with group
                specific optimization options.
        """
        group_id = len(self.param_groups)
        param_group = self._preprocess_param_group(param_group)
        self.parameters += tuple(param_group.get("params"))

        # Options the group does not set explicitly fall back to the optimizer-wide defaults.
        for name, default in self.defaults.items():
            param_group.setdefault(name, default)

        lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
        weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
        param_group["lr"] = lr
        param_group["weight_decay"] = weight_decay
        self.param_groups.append(param_group)
        self.group_start_id.append(self.group_start_id[-1] + len(param_group.get("params")))

    @staticmethod
    def _parameters_base_check(parameters, param_info):
        """Check that `parameters` is a non-empty iterable and return it as a list."""
        if parameters is None:
            raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
                            f"but got {type(parameters)}.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
        return parameters

    def _decay_weight(self, weight_decay, params, gradients):
        """Return `gradients` with ``weight_decay * param`` added, or unchanged when decay is 0."""
        if weight_decay != 0.:
            weight_decay = Tensor(weight_decay, mstype.float32)
            gradients = self.map_(F.partial(_apply_decay, weight_decay), params, gradients)
        return gradients

    def _preprocess_param_group(self, param_group):
        """
        Validate a parameter group and normalize its ``"params"`` entry to a list of Parameters.

        Raises:
            TypeError: If `param_group` is not a dict, its params are a set, or any element is
                not a Parameter.
            ValueError: If a parameter already belongs to a previously added group.
        """
        if not isinstance(param_group, dict):
            raise TypeError('Param group must be a dict.')

        params = param_group['params']
        if isinstance(params, Parameter):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. '
                            'Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, Parameter):
                # Format the type with an f-string: concatenating `str + type` would itself raise.
                raise TypeError(f"Optimizer can only optimize Parameters, but one of the params is {type(param)}")

        if len(param_group['params']) != len(set(param_group['params'])):
            logger.warning("Optimizer contains a parameter group with duplicate parameters.")

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))
        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group.")
        return param_group

    def _build_single_lr(self, learning_rate, name):
        """Check lr value, and wrap it in a float32 Parameter called `name`."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return Parameter(Tensor(learning_rate, mstype.float32), name)

        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim == 0:
                return Parameter(learning_rate.astype(mstype.float32), name)
            raise ValueError("For 'Optimizer', if 'learning_rate' is a Tensor, "
                             "then it should be scalar Tensor")

        raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
                        "but got {}.".format(type(learning_rate)))

    def _preprocess_weight_decay(self, weight_decay):
        """Check that `weight_decay` is a non-negative int/float and return it as a float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
        else:
            raise TypeError("For 'Optimizer', the argument 'weight_decay' must be int or "
                            "float, but got {}.".format(type(weight_decay)))
        return weight_decay

    def construct(self, *hyper_params):
        """Perform one optimization step; concrete optimizers must override this."""
        raise NotImplementedError
|
|
232
|
+
|
|
233
|
+
# Module-level primitive instances shared by the weight-decay overloads registered below.
op_add = P.AddN()
op_gather = P.Gather()
op_mul = P.Mul()

# Type-dispatched weight-decay function used by Optimizer._decay_weight; overloads are
# registered below for dense ("Tensor") and sparse ("RowTensor") gradients.
_apply_decay = C.MultitypeFuncGraph("apply_decay")
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
@_apply_decay.register("Tensor", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
    """
    Add weight decay to a sparse (RowTensor) gradient.

    Only the rows of `weight` selected by ``gradient.indices`` (gathered along axis 0) contribute:
    the new values are ``weight[indices] * weight_decay + gradient.values``, and the indices and
    dense shape of `gradient` are kept.
    """
    row_ids = gradient.indices
    decay = F.cast(weight_decay, F.dtype(weight))
    decayed_rows = op_gather(weight, row_ids, 0) * decay
    new_values = op_add((decayed_rows, gradient.values))
    return RowTensorInner(row_ids, new_values, gradient.dense_shape)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
@_apply_decay.register("Tensor", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, weight, gradient):
    """Return the dense gradient with weight decay added: ``weight * weight_decay + gradient``."""
    decay = F.cast(weight_decay, F.dtype(weight))
    decayed_weight = op_mul(weight, decay)
    return op_add((decayed_weight, gradient))
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# Copyright 2023 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""sgd"""
|
|
16
|
+
from __future__ import absolute_import
|
|
17
|
+
|
|
18
|
+
from mindspore.ops import functional as F, composite as C, operations as P
|
|
19
|
+
from mindspore.common.tensor import Tensor
|
|
20
|
+
import mindspore.common.dtype as mstype
|
|
21
|
+
from mindspore import _checkparam as Validator
|
|
22
|
+
from mindspore.experimental.optim.optimizer import Optimizer
|
|
23
|
+
|
|
24
|
+
# Type-dispatched per-parameter SGD update.
_sgd_opt = C.MultitypeFuncGraph("sgd_opt")


@_sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, accum, stat):
    """
    Run the SGD primitive `opt` on one parameter.

    Returns True, attached through ``F.depend`` to the result of the `opt` call so the update is
    kept in the graph.
    """
    update = opt(weight, gradient, learning_rate, accum, momentum, stat)
    return F.depend(True, update)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SGD(Optimizer):
|
|
36
|
+
r"""
|
|
37
|
+
Stochastic Gradient Descent optimizer.
|
|
38
|
+
|
|
39
|
+
.. math::
|
|
40
|
+
v_{t+1} = u \ast v_{t} + gradient \ast (1-dampening)
|
|
41
|
+
|
|
42
|
+
If nesterov is True:
|
|
43
|
+
|
|
44
|
+
.. math::
|
|
45
|
+
p_{t+1} = p_{t} - lr \ast (gradient + u \ast v_{t+1})
|
|
46
|
+
|
|
47
|
+
If nesterov is False:
|
|
48
|
+
|
|
49
|
+
.. math::
|
|
50
|
+
p_{t+1} = p_{t} - lr \ast v_{t+1}
|
|
51
|
+
|
|
52
|
+
To be noticed, for the first step, :math:`v_{t+1} = gradient`.
|
|
53
|
+
|
|
54
|
+
Here : where p, v and u denote the parameters, accum, and momentum respectively.
|
|
55
|
+
|
|
56
|
+
.. warning::
|
|
57
|
+
This is an experimental optimizer API that is subject to change.
|
|
58
|
+
This module must be used with lr scheduler module in `LRScheduler Class
|
|
59
|
+
<https://www.mindspore.cn/docs/en/r2.2/api_python/mindspore.experimental.html#lrscheduler-class>`_ .
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
params (Union[list(Parameter), list(dict)]): list of parameters to optimize or dicts defining
|
|
63
|
+
parameter groups.
|
|
64
|
+
lr (Union[int, float, Tensor]): learning rate.
|
|
65
|
+
momentum (Union[int, float], optional): momentum factor. Default: ``0``.
|
|
66
|
+
weight_decay (float, optional): weight decay (L2 penalty). Default: ``0``.
|
|
67
|
+
dampening (Union[int, float], optional): dampening for momentum. Default: ``0``.
|
|
68
|
+
nesterov (bool, optional): enables Nesterov momentum. Default: ``False``.
|
|
69
|
+
|
|
70
|
+
Keyword Args:
|
|
71
|
+
maximize (bool, optional): maximize the params based on the objective, instead of minimizing.
|
|
72
|
+
Default: ``False``.
|
|
73
|
+
|
|
74
|
+
Inputs:
|
|
75
|
+
- **gradients** (tuple[Tensor]) - The gradients of `params`.
|
|
76
|
+
|
|
77
|
+
Raises:
|
|
78
|
+
ValueError: If the learning rate is not int, float or Tensor.
|
|
79
|
+
ValueError: If the learning rate is less than 0.
|
|
80
|
+
ValueError: If the `momentum` or `weight_decay` value is less than 0.0.
|
|
81
|
+
ValueError: If the `momentum`, `dampening` or `weight_decay` value is not int or float.
|
|
82
|
+
ValueError: If the `nesterov` and `maximize` is not bool.
|
|
83
|
+
ValueError: If the `nesterov` is true, `momentum` is not positive or `dampening` is not 0.0.
|
|
84
|
+
|
|
85
|
+
Supported Platforms:
|
|
86
|
+
``Ascend`` ``GPU`` ``CPU``
|
|
87
|
+
|
|
88
|
+
Examples:
|
|
89
|
+
>>> import mindspore
|
|
90
|
+
>>> from mindspore import nn
|
|
91
|
+
>>> from mindspore.experimental import optim
|
|
92
|
+
>>> # Define the network structure of LeNet5. Refer to
|
|
93
|
+
>>> # https://gitee.com/mindspore/docs/blob/r2.2/docs/mindspore/code/lenet.py
|
|
94
|
+
>>> net = LeNet5()
|
|
95
|
+
>>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
|
|
96
|
+
>>> optimizer = optim.SGD(net.trainable_params(), lr=0.1)
|
|
97
|
+
>>> def forward_fn(data, label):
|
|
98
|
+
... logits = net(data)
|
|
99
|
+
... loss = loss_fn(logits, label)
|
|
100
|
+
... return loss, logits
|
|
101
|
+
>>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
|
|
102
|
+
>>> def train_step(data, label):
|
|
103
|
+
... (loss, _), grads = grad_fn(data, label)
|
|
104
|
+
... optimizer(grads)
|
|
105
|
+
... return loss
|
|
106
|
+
"""
|
|
107
|
+
def __init__(self, params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False, *,
             maximize=False):
    """
    Validate the hyper-parameters, build the parameter groups and create the
    per-parameter optimizer state (``accum`` initialised to zeros, ``stat`` to ones).

    Raises:
        ValueError: If `lr` or `momentum` is negative, or if any parameter group enables
            `nesterov` without a positive `momentum` and a zero `dampening`.
    """
    Validator.check_value_type("lr", lr, [float, int, Tensor], self.cls_name)
    if lr < 0.0:
        raise ValueError("Invalid learning rate: {}".format(lr))
    Validator.check_value_type("momentum", momentum, [int, float], self.cls_name)
    if momentum < 0.0:
        raise ValueError("Invalid momentum value: {}".format(momentum))
    momentum = float(momentum)
    Validator.check_value_type("nesterov", nesterov, [bool], self.cls_name)
    Validator.check_value_type("maximize", maximize, [bool], self.cls_name)

    defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                    weight_decay=weight_decay, nesterov=nesterov,
                    maximize=maximize, grad_centralization=False)
    super(SGD, self).__init__(params, defaults)
    for group in self.param_groups:
        Validator.check_value_type("dampening", group.get("dampening"), [int, float], self.cls_name)
        group["dampening"] = float(group.get("dampening"))
        # construct() builds P.SGD from each group's own values, so validate those
        # (possibly overridden) values rather than only the constructor defaults.
        group_momentum = group.get("momentum")
        group_dampening = group.get("dampening")
        if group.get("nesterov") and (group_momentum <= 0.0 or group_dampening != 0.0):
            raise ValueError("For 'SGD', if 'nesterov' is true, 'momentum' must be > 0.0 and 'dampening' must "
                             "equal to 0.0, but got 'momentum' {}, 'dampening' {}".format(group_momentum,
                                                                                          group_dampening))
    self.accum = self.parameters.clone(prefix="accum", init='zeros')
    self.stat = self.parameters.clone(prefix="stat", init='ones')
    self.op_cast = P.Cast()
|
|
132
|
+
|
|
133
|
+
def construct(self, gradients):
    """
    Apply one SGD update to every parameter group.

    Args:
        gradients (tuple[Tensor]): Gradients of ``self.parameters``, in the same order.

    Returns:
        bool: Always True.
    """
    for group_id, group in enumerate(self.param_groups):
        opt = P.SGD(group.get("dampening"), group.get("weight_decay"), group.get("nesterov"))
        lr = group.get("lr")
        if isinstance(lr, float):
            lr = self.op_cast(lr, mstype.float32)
        maximize = group.get("maximize")
        momentum = self.op_cast(group.get("momentum"), mstype.float32)
        start_id = self.group_start_id[group_id]
        end_id = self.group_start_id[group_id + 1]
        # Unary minus is not defined on a tuple, so negate each gradient tensor
        # individually when maximizing the objective.
        grads = tuple([grad if not maximize else -grad for grad in gradients[start_id: end_id]])
        self.hyper_map(F.partial(_sgd_opt, opt, momentum, lr), grads,
                       self.parameters[start_id: end_id], self.accum[start_id: end_id],
                       self.stat[start_id: end_id])
    return True
|
mindspore/gen_ops.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# Copyright 2023-2025 Huawei Technologies Co., Ltd
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ============================================================================
|
|
15
|
+
"""
|
|
16
|
+
Generate operator definition from ops.yaml
|
|
17
|
+
"""
|
|
18
|
+
import sys
|
|
19
|
+
import os
|
|
20
|
+
import yaml
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def generate_py_op_func(yaml_data, doc_data):
    """
    Generate the Python functional wrappers for every operator in ``yaml_data``.

    Each generated function fetches a cached primitive instance, constructed from the
    arguments that declare an ``init`` default. It then calls that instance with the
    remaining arguments.

    Args:
        yaml_data (dict): Operator definitions (from ops.yaml) keyed by operator name.
            Each entry may contain an ``args`` mapping and a ``func_name`` override.
        doc_data (dict): Operator documentation (from ops_doc.yaml) keyed by operator
            name; each entry may contain a ``description``. May be None.

    Returns:
        str: The generated Python source code.
    """
    op_desc_dict = {name: desc.get("description") for name, desc in (doc_data or {}).items()}

    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        # Undocumented operators get an empty docstring instead of the text "None".
        description = op_desc_dict.get(operator_name) or ''
        args = operator_data.get('args') or {}
        func_name = operator_data.get('func_name') or operator_name
        class_name = ''.join(word.capitalize() for word in operator_name.split('_'))

        func_args = []
        primitive_init_args = []
        input_args = []
        for arg_name, arg_info in args.items():
            init_value = arg_info.get('init')
            # Compare with None explicitly: falsy defaults such as 0, False or ""
            # still mark a primitive init argument.
            if init_value is not None:
                if arg_info.get('dtype') == 'str':
                    init_value = f'"{init_value}"'
                func_args.append(f"{arg_name}={init_value}")
                primitive_init_args.append(arg_name)
            else:
                func_args.append(arg_name)
                input_args.append(arg_name)

        function_code = f"""
def {func_name}({', '.join(func_args)}):
    \"\"\"
    {description}
    \"\"\"
    {operator_name}_op = _get_cache_prim(P.{class_name})({', '.join(primitive_init_args)})
    return {operator_name}_op({', '.join(input_args)})
"""
        gen_py += function_code

    return gen_py
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _py_type_cast_expr(arg_name, dtype, type_cast):
    """Return the expression assigned to ``self.<arg_name>`` in a generated primitive."""
    if not type_cast:
        return arg_name
    type_cast_set = {ct.strip() for ct in type_cast.split(",")}
    type_cast_list = []
    if 'int' in type_cast_set:
        type_cast_list.append('INT')
    if 'tuple[int]' in type_cast_set:
        type_cast_list.append('TUPLE')
    # add more type cast kind here
    if not type_cast_list:
        # No recognised source kind: keep the plain assignment instead of emitting a
        # dangling "TypeCastKind." reference.
        return arg_name
    cast_kind = '_OR_'.join(type_cast_list)
    if dtype == 'tuple[int]':
        cast_kind += '_TO_TUPLE'
    elif dtype == 'list[int]':
        cast_kind += '_TO_LIST'
    return f'type_it({arg_name}, TypeCastKind.{cast_kind})'


def generate_py_primitive(yaml_data):
    """
    Generate one Python ``Primitive`` subclass per operator in ``yaml_data``.

    Arguments that declare an ``init`` default become constructor parameters. They are
    stored on the instance, optionally converted with ``type_it``, and appended to the
    call arguments. Arguments without ``init`` are passed through ``*args`` at call time.

    Args:
        yaml_data (dict): Operator definitions (from ops.yaml) keyed by operator name.

    Returns:
        str: The generated Python source code.
    """
    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        args = operator_data.get('args') or {}
        class_name = ''.join(word.capitalize() for word in operator_name.split('_'))

        init_args_with_default = []
        init_args = []
        args_assign = []
        for arg_name, arg_info in args.items():
            init_value = arg_info.get('init')
            if init_value is None:
                # Not a constructor argument; it is supplied at call time.
                continue
            dtype = arg_info.get('dtype')
            if dtype == 'str':
                init_value = f'"{init_value}"'
            init_args_with_default.append(f"{arg_name}={init_value}")
            init_args.append(arg_name)
            cast_expr = _py_type_cast_expr(arg_name, dtype, arg_info.get('type_cast'))
            args_assign.append(f"        self.{arg_name} = {cast_expr}")

        init_params = ''.join(f', {arg}' for arg in init_args_with_default)
        # An empty method body is a syntax error, so fall back to "pass".
        init_body = '\n'.join(args_assign) if args_assign else '        pass'
        call_args = ''.join(f', self.{arg}' for arg in init_args)
        # NOTE(review): the generated __init__ does not call Primitive.__init__;
        # confirm whether the Primitive base class requires it.
        primitive_code = f"""
class {class_name}(Primitive):
    def __init__(self{init_params}):
{init_body}

    def __call__(self, *args):
        return super().__call__(*args{call_args})
"""
        gen_py += primitive_code
    return gen_py
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _to_cc_dtype(dtype):
    """Map an ops.yaml dtype string (e.g. ``tuple[int]``) to its C++ enum name (``DT_ARRAY_INT``)."""
    cc_name = dtype.replace('[', '_').replace(']', '').replace('tuple', 'array').replace('list', 'array')
    return 'DT_' + cc_name.upper()


def generate_cc_opdef(yaml_data):
    """
    Generate the C++ ``OpDef`` definitions and the ``gOpDefTable`` lookup table.

    Args:
        yaml_data (dict): Operator definitions (from ops.yaml) keyed by operator name.
            Each entry may contain ``args`` and ``returns`` mappings whose values carry
            a ``dtype`` (and, for args, an optional ``init``).

    Returns:
        str: C++ source: one ``OpDef g<ClassName>`` per operator, followed by the
        ``gOpDefTable`` map from operator name to definition.
    """
    opdefs = []
    table_entries = []
    for operator_name, operator_data in yaml_data.items():
        args = operator_data.get('args') or {}
        returns = operator_data.get('returns') or {}
        class_name = ''.join(word.capitalize() for word in operator_name.split('_'))
        table_entries.append(f'\n  {{"{operator_name}", &g{class_name}}},')

        opdef_cc = f'\nOpDef g{class_name} = {{\n  .name_ = "{operator_name}",\n  .args_ = {{'
        for arg_name, arg_info in args.items():
            cc_dtype = _to_cc_dtype(arg_info.get('dtype'))
            # An argument with an ``init`` default is a primitive constructor argument.
            as_init_arg = 0 if arg_info.get('init') is None else 1
            opdef_cc += (f'\n    {{.arg_name_ = "{arg_name}", .arg_dtype_ = {cc_dtype}, '
                         f'.as_init_arg_ = {as_init_arg}}},')
        opdef_cc += '\n  },\n  .returns_ = {'
        for return_name, return_info in returns.items():
            cc_dtype = _to_cc_dtype(return_info.get('dtype'))
            opdef_cc += f'\n    {{.arg_name_ = "{return_name}", .arg_dtype_ = {cc_dtype}}},'
        opdef_cc += '\n  },\n};'
        opdefs.append(opdef_cc)

    table = ('\nstd::unordered_map<std::string, OpDefPtr> gOpDefTable = {'
             + ''.join(table_entries) + '\n};')
    return ''.join(opdefs) + table
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
if __name__ == "__main__":
    # Optional first CLI argument: root of the MindSpore source tree (default: CWD).
    work_path = sys.argv[1] if len(sys.argv) > 1 else ''

    yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops.yaml')
    doc_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_doc.yaml')
    op_py_path = os.path.join(work_path, 'mindspore/python/mindspore/gen_ops_def.py')
    op_cc_path = os.path.join(work_path, 'mindspore/core/ops/gen_ops_def.cc')

    # Use an explicit encoding: the YAML docs may contain non-ASCII text and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(yaml_path, 'r', encoding='utf-8') as yaml_file:
        ops_yaml = yaml.safe_load(yaml_file)
    with open(doc_yaml_path, 'r', encoding='utf-8') as doc_file:
        doc_yaml = yaml.safe_load(doc_file)

    cc_code = generate_cc_opdef(ops_yaml)
    cc_code += """
} // namespace mindspore::ops
"""

    py_license_str = """# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
    pyheader = """
\"\"\"Operators definition generated by gen_ops.py, includes functions and primitive classes.\"\"\"

from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.arg_dtype_cast import TypeCastKind, type_it
"""
    cc_license_str = """/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */"""
    ccheader = """
#include "op_def.h"
namespace mindspore::ops {
"""

    py_prim = generate_py_primitive(ops_yaml)
    py_func = generate_py_op_func(ops_yaml, doc_yaml)
    with open(op_py_path, 'w', encoding='utf-8') as py_file:
        py_file.write(py_license_str + pyheader + py_prim + py_func)

    with open(op_cc_path, 'w', encoding='utf-8') as cc_file:
        cc_file.write(cc_license_str + ccheader + cc_code)
|