mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
mindspore/ops/_utils/utils.py
CHANGED
|
@@ -78,6 +78,11 @@ def get_broadcast_shape(x_shape, y_shape, prim_name, arg_name1="x", arg_name2="y
|
|
|
78
78
|
return broadcast_shape
|
|
79
79
|
|
|
80
80
|
|
|
81
|
+
def dim_not_equal(dim1, dim2):
|
|
82
|
+
"""Compare dim in shape"""
|
|
83
|
+
return dim1 != dim2 and dim1 >= 0 and dim2 >= 0
|
|
84
|
+
|
|
85
|
+
|
|
81
86
|
def get_concat_offset(x_shp, x_type, axis, prim_name):
|
|
82
87
|
"""for concat and concatoffset check args and compute offset"""
|
|
83
88
|
validator.check_value_type("shape", x_shp, [tuple, list], prim_name)
|
|
@@ -98,7 +103,7 @@ def get_concat_offset(x_shp, x_type, axis, prim_name):
|
|
|
98
103
|
for i in range(1, len(x_shp)):
|
|
99
104
|
v = x_shp[i]
|
|
100
105
|
for j in range(rank_base):
|
|
101
|
-
if j != axis and v[j]
|
|
106
|
+
if j != axis and dim_not_equal(v[j], x_shp[0][j]):
|
|
102
107
|
raise ValueError(f"The shape of the two input elements of the Concat operator do not match:"
|
|
103
108
|
f"shape[0] = {x_shp[0]} and shape[{i}] = {x_shp[i]}.")
|
|
104
109
|
offset.append(all_shp)
|
|
@@ -155,7 +155,7 @@ def _get_prefix(indices_shape, axis_size, indices_dtype):
|
|
|
155
155
|
_check(indices_shape)
|
|
156
156
|
indices_len = len(indices_shape)
|
|
157
157
|
if indices_len == 1:
|
|
158
|
-
prefix = P.Range()(Tensor(0, indices_dtype),
|
|
158
|
+
prefix = P.Range()(Tensor(0, indices_dtype), F.fill(
|
|
159
159
|
indices_dtype, (), axis_size), Tensor(1, indices_dtype))
|
|
160
160
|
return prefix
|
|
161
161
|
|
|
@@ -850,9 +850,9 @@ def get_fill_vmap_rule(prim, axis_size):
|
|
|
850
850
|
|
|
851
851
|
|
|
852
852
|
@constexpr
|
|
853
|
-
def to_tensor_with_type(x,
|
|
853
|
+
def to_tensor_with_type(x, dtype):
|
|
854
854
|
"""x to Tensor with type"""
|
|
855
|
-
return Tensor(x,
|
|
855
|
+
return Tensor(x, dtype)
|
|
856
856
|
|
|
857
857
|
|
|
858
858
|
@vmap_rules_getters.register(P.FillV2)
|
mindspore/ops/_vmap/vmap_base.py
CHANGED
|
@@ -250,7 +250,7 @@ def vmap_monad_rule(prim, axis_size):
|
|
|
250
250
|
def _bdim_at_any(x, src, dst, axis_size):
|
|
251
251
|
"""
|
|
252
252
|
Moves source axes of an array to the dst axis, and other axes remain in their original order. If the source axes
|
|
253
|
-
is
|
|
253
|
+
is ``None``, broadcasts the array at dst axis with axis_size.
|
|
254
254
|
|
|
255
255
|
Args:
|
|
256
256
|
x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
|
|
@@ -272,7 +272,7 @@ def _bdim_at_any(x, src, dst, axis_size):
|
|
|
272
272
|
def _bdim_at_front(x, src, axis_size):
|
|
273
273
|
"""
|
|
274
274
|
Moves source axes of an array to the foremost, and other axes remain in their original order. If the source axes
|
|
275
|
-
is
|
|
275
|
+
is ``None``, broadcasts the array at foremost axis with axis_size.
|
|
276
276
|
|
|
277
277
|
Args:
|
|
278
278
|
x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
|
|
@@ -289,7 +289,7 @@ def _bdim_at_front(x, src, axis_size):
|
|
|
289
289
|
def _bdim_at_back(x, src, axis_size):
|
|
290
290
|
"""
|
|
291
291
|
Moves source axes of an array to the last, and other axes remain in their original order. If the source axes
|
|
292
|
-
is
|
|
292
|
+
is ``None``, broadcasts the array at foremost axis with axis_size.
|
|
293
293
|
|
|
294
294
|
Args:
|
|
295
295
|
x (Tensor or Scalar): The input tensor or scalar. The data type should be one of the following types: float16,
|
|
@@ -190,8 +190,8 @@ def _reshape_expand_dims(src_dim, dst_size, target, prim_name):
|
|
|
190
190
|
@_primexpr
|
|
191
191
|
def _get_new_size_by_index(input_size, batch_size, index):
|
|
192
192
|
"""Get the new size of input_size by multiplying input_size[index] by batch_size."""
|
|
193
|
-
new_size = ()
|
|
194
193
|
if input_size is None:
|
|
194
|
+
new_size = ()
|
|
195
195
|
return new_size
|
|
196
196
|
new_size = list(input_size)
|
|
197
197
|
new_size[index] *= batch_size
|
|
@@ -62,8 +62,9 @@ def get_broadcast_binary_op_grad_vmap_rule(prim, axis_size):
|
|
|
62
62
|
y_shape = F.shape(y)
|
|
63
63
|
g_shape = F.shape(g)
|
|
64
64
|
|
|
65
|
-
|
|
66
|
-
|
|
65
|
+
is_dim_ok = x_dim == y_dim and x_dim == g_dim
|
|
66
|
+
is_shape_ok = x_shape == y_shape and x_shape == g_shape
|
|
67
|
+
if is_dim_ok and is_shape_ok:
|
|
67
68
|
dx, dy = prim(x, y, g)
|
|
68
69
|
return (dx, x_dim), (dy, y_dim)
|
|
69
70
|
|
|
@@ -113,8 +114,9 @@ def get_broadcast_grad_grad_vmap_rule(prim, axis_size):
|
|
|
113
114
|
dx1_shape = F.shape(dx1)
|
|
114
115
|
dx2_shape = F.shape(dx2)
|
|
115
116
|
|
|
116
|
-
|
|
117
|
-
|
|
117
|
+
is_dim_ok = x1_dim == x2_dim and dx1_dim == dx2_dim and x1_dim == dx1_dim
|
|
118
|
+
is_shape_ok = x1_shape == x2_shape and dx1_shape == dx2_shape
|
|
119
|
+
if is_dim_ok and is_shape_ok:
|
|
118
120
|
sopd_x1, sopd_x2, sopd_grad = prim(x1, x2, dx1, dx2)
|
|
119
121
|
return (sopd_x1, x1_dim), (sopd_x2, x1_dim), (sopd_grad, x1_dim)
|
|
120
122
|
|
|
@@ -66,6 +66,7 @@ def _broadcast_shape(nd, x_ndim, x_shape):
|
|
|
66
66
|
@vmap_rules_getters.register(P.ApproximateEqual)
|
|
67
67
|
@vmap_rules_getters.register(P.TruncateDiv)
|
|
68
68
|
@vmap_rules_getters.register(P.TruncateMod)
|
|
69
|
+
|
|
69
70
|
def get_broadcast_binary_op_vmap_rule(prim, axis_size):
|
|
70
71
|
"""VmapRule for binary operations with broadcasting, such as `Add` and `Sub`."""
|
|
71
72
|
|
|
@@ -216,8 +217,9 @@ def get_lerp_vamp_rule(prim, axis_size):
|
|
|
216
217
|
# Both broadcast end and weight to start.
|
|
217
218
|
else:
|
|
218
219
|
weight_shape = F.shape(weight)
|
|
219
|
-
|
|
220
|
-
|
|
220
|
+
is_dim_ok = start_dim == end_dim and start_dim == weight_dim
|
|
221
|
+
is_shape_ok = start_shape == end_shape and start_shape == weight_shape
|
|
222
|
+
if is_dim_ok and is_shape_ok:
|
|
221
223
|
out = prim(start, end, weight)
|
|
222
224
|
return out, start_dim
|
|
223
225
|
start, end = broadcast_a_b_shape(start_bdim, end_bdim)
|
|
@@ -900,3 +902,4 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
|
|
|
900
902
|
get_unop_vmap_rule = vmap_rules_getters.register(P.Trunc)(get_unop_vmap_rule)
|
|
901
903
|
get_unop_vmap_rule = vmap_rules_getters.register(P.PopulationCount)(get_unop_vmap_rule)
|
|
902
904
|
get_unop_vmap_rule = vmap_rules_getters.register(P.Square)(get_unop_vmap_rule)
|
|
905
|
+
get_unop_vmap_rule = vmap_rules_getters.register(P.Eps)(get_unop_vmap_rule)
|
|
@@ -325,9 +325,10 @@ def get_bce_with_logits_loss_vamp_rule(prim, axis_size):
|
|
|
325
325
|
# If rank is larger than 1, we need to reduce result when reduction != 'none'
|
|
326
326
|
if max_rank > 1:
|
|
327
327
|
reduce_indexes = tuple(range(1, max_rank))
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
328
|
+
logits_dim_ok = logits_dim == label_dim and logits_dim == weight_dim and logits_dim == pos_weight_dim
|
|
329
|
+
shape = F.shape(logits)
|
|
330
|
+
shape_ok = shape == F.shape(label) and shape == F.shape(weight) and shape == F.shape(pos_weight)
|
|
331
|
+
if logits_dim_ok and shape_ok:
|
|
331
332
|
if prim_reduction == 'none':
|
|
332
333
|
output = prim(logits, label, weight, pos_weight)
|
|
333
334
|
elif prim_reduction in ('mean', 'sum'):
|
|
@@ -798,7 +799,8 @@ def get_instance_norm_rule(prim, axis_size):
|
|
|
798
799
|
output_x, updated_moving_mean, updated_moving_variance = prim(input_x, gamma, beta, mean, variance, u_monad)
|
|
799
800
|
return (output_x, None), (updated_moving_mean, None), (updated_moving_variance, None)
|
|
800
801
|
|
|
801
|
-
|
|
802
|
+
precondition = gamma_dim != 0 or beta_dim != gamma_dim or mean_dim != gamma_dim or variance_dim != gamma_dim
|
|
803
|
+
if precondition:
|
|
802
804
|
# pylint: disable=too-many-format-args
|
|
803
805
|
raise ValueError(
|
|
804
806
|
"For `{}`, the source axis of `var` must be equal to `accum` and `accum_update`, and not equal to 0, "
|
|
@@ -1679,7 +1681,8 @@ def get_rmsprop_vmap_rule(prim, axis_size):
|
|
|
1679
1681
|
res = prim(var, mean_square, moment, lr, grad, decay, momentum, epsilon,
|
|
1680
1682
|
u_monad) # low dimensional operator;
|
|
1681
1683
|
return (res, None)
|
|
1682
|
-
|
|
1684
|
+
precondition = var_dim != 0 or var_dim != mean_square_dim or var_dim != moment_dim or var_dim != grad_dim
|
|
1685
|
+
if precondition:
|
|
1683
1686
|
raise ValueError(
|
|
1684
1687
|
f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_square_dim' "
|
|
1685
1688
|
f"and 'moment_dim' and 'grad_dim' and not equal to 0, "
|
|
@@ -1735,8 +1738,8 @@ def get_apply_centered_rmsprop_vmap_rule(prim, axis_size):
|
|
|
1735
1738
|
var = prim(var, mean_grad, mean_square,
|
|
1736
1739
|
mom, grad, lr, rho, momentum, eps, u_monad)
|
|
1737
1740
|
return (var, None)
|
|
1738
|
-
|
|
1739
|
-
if
|
|
1741
|
+
precondition = var_dim != 0 or var_dim != mean_grad_dim or var_dim != mean_square_dim or var_dim != mom_dim
|
|
1742
|
+
if precondition:
|
|
1740
1743
|
raise ValueError(
|
|
1741
1744
|
f"For '{prim_name}', the source axis of 'var' must be equal to 'mean_grad_dim' "
|
|
1742
1745
|
f"and 'mean_square_dim' and 'mom_dim' and not equal to 0, "
|
|
@@ -2000,6 +2003,57 @@ def get_sparse_apply_ftrl_vmap_rule(prim, axis_size):
|
|
|
2000
2003
|
return vmap_rule
|
|
2001
2004
|
|
|
2002
2005
|
|
|
2006
|
+
@vmap_rules_getters.register(P.Dense)
def get_dense_vmap_rule(prim, axis_size):
    """VmapRule for `Dense` operation.

    Every batched operand is first moved so that its vmap axis is axis 0. The input is then
    collapsed to three dims (vmap axis, product of the inner dims, last dim), multiplied with the
    weight by a `BatchMatMul` that transposes its second operand, and reshaped back to the input's
    leading dims. When a bias is present it is reshaped so that it broadcasts over the inner dims
    and is added to the product. The returned output carries its vmap axis at position 0.
    """
    if isinstance(prim, str):
        prim = Primitive(prim)

    bmm_transpose_w = P.BatchMatMul(transpose_b=True)

    @_primexpr
    def _split_shape(shape):
        """Return (shape[0], product of shape[1:-1], shape[-1]); the product is 1 if there are no inner dims."""
        inner = 1
        for dim_size in shape[1:-1]:
            inner *= dim_size
        return shape[0], inner, shape[-1]

    def vmap_rule(x_bdim, w_bdim, b_bdim):
        is_all_none, result = vmap_general_preprocess(prim, x_bdim, w_bdim, b_bdim)
        if is_all_none:
            return result

        x, x_dim = x_bdim
        w, w_dim = w_bdim
        b, b_dim = b_bdim

        # Bring the vmap axis of every operand to the front.
        x = _bdim_at_front(x, x_dim, axis_size)
        w = _bdim_at_front(w, w_dim, axis_size)
        if b is not None:
            b = _bdim_at_front(b, b_dim, axis_size)

        batched_x_shape = x.shape
        lead, inner, last = _split_shape(batched_x_shape)

        # NOTE(review): assumes the batched weight is (lead, out_features, in_features) so that
        # transpose_b yields (lead, inner, out_features) -- confirm against the Dense op spec.
        product = bmm_transpose_w(x.reshape(lead, inner, last), w)
        out_shape = tuple(batched_x_shape[:-1]) + (product.shape[-1],)
        out = product.reshape(out_shape)

        if b is not None:
            # Keep the bias' vmap axis and its trailing axis, with singleton axes in between,
            # so it broadcasts against `out`.
            bias_shape = (lead,) + (1,) * (len(out_shape) - 2) + (b.shape[-1],)
            out = out + b.reshape(bias_shape)

        return out, 0

    return vmap_rule
|
|
2055
|
+
|
|
2056
|
+
|
|
2003
2057
|
# Unary vmap
|
|
2004
2058
|
get_unop_vmap_rule = vmap_rules_getters.register(P.Elu)(get_unop_vmap_rule)
|
|
2005
2059
|
get_unop_vmap_rule = vmap_rules_getters.register(P.ReLU)(get_unop_vmap_rule)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
2
|
+
#
|
|
3
|
+
# Copyright 2023-2024 Huawei Technologies Co., Ltd
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
# ============================================================================
|
|
17
|
+
"""Operator argument data type cast function."""
|
|
18
|
+
from enum import Enum
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class TypeCastKind(Enum):
    """Conversions supported by :func:`type_it`.

    INT_TO_TUPLE: accepts an ``int`` or a ``tuple``; the result is always a ``tuple``.
    INT_OR_TUPLE_TO_LIST: accepts an ``int``, a ``tuple`` or a ``list``; the result is always a ``list``.
    """
    INT_TO_TUPLE = 1
    INT_OR_TUPLE_TO_LIST = 2


def type_it(src_data, cast_type):
    """
    Cast an operator argument to the container type selected by `cast_type`.

    Args:
        src_data (Union[int, tuple, list]): The argument value to convert. ``bool`` values are
            accepted wherever ``int`` is, because ``bool`` subclasses ``int``.
        cast_type (TypeCastKind): The conversion to apply.

    Returns:
        tuple or list.

        - ``INT_TO_TUPLE``: a tuple is returned unchanged (same object); an int is wrapped
          as ``(src_data,)``.
        - ``INT_OR_TUPLE_TO_LIST``: a list is returned unchanged (same object); an int is wrapped
          as ``[src_data]``; a tuple is copied into a new list with the same items.

    Raises:
        TypeError: If `src_data` has a type the selected conversion does not accept, or if
            `cast_type` is not a supported `TypeCastKind`.
    """
    if cast_type == TypeCastKind.INT_TO_TUPLE:
        if isinstance(src_data, tuple):
            return src_data
        if isinstance(src_data, int):
            return (src_data,)
        raise TypeError(f'{src_data} is the wrong data type.')

    if cast_type == TypeCastKind.INT_OR_TUPLE_TO_LIST:
        if isinstance(src_data, list):
            return src_data
        if isinstance(src_data, int):
            return [src_data]
        if isinstance(src_data, tuple):
            # list() copies the items directly; an identity comprehension adds nothing.
            return list(src_data)
        raise TypeError(f'{src_data} is the wrong data type.')

    raise TypeError("Unsupported type cast")
|
mindspore/ops/composite/base.py
CHANGED
|
@@ -20,6 +20,7 @@ from __future__ import absolute_import
|
|
|
20
20
|
from functools import partial
|
|
21
21
|
|
|
22
22
|
from types import FunctionType, MethodType
|
|
23
|
+
import numpy as np
|
|
23
24
|
import mindspore as ms
|
|
24
25
|
from mindspore import context
|
|
25
26
|
from mindspore.common.parameter import Parameter, ParameterTuple
|
|
@@ -28,7 +29,8 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
|
|
|
28
29
|
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
|
29
30
|
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
|
30
31
|
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
|
|
31
|
-
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_
|
|
32
|
+
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
|
|
33
|
+
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_
|
|
32
34
|
from mindspore.common import dtype as mstype
|
|
33
35
|
from mindspore.common.api import jit, _pynative_executor, _wrap_func
|
|
34
36
|
from mindspore.common.api import _add_flags, _core
|
|
@@ -36,7 +38,8 @@ from mindspore.ops.primitive import Primitive
|
|
|
36
38
|
from mindspore.ops import signature as sig
|
|
37
39
|
|
|
38
40
|
# Public names re-exported by `from ... import *`. Entries must be name strings:
# Python raises TypeError during a star-import if any item of __all__ is not a str.
__all__ = ['TupleAdd_', 'ListAdd_', 'UnpackCall_', 'TupleGetItemTensor_', 'SequenceSliceGetItem_',
           'ListSliceSetItem_', 'ZerosLike_', 'TensorIndexGetitem_', 'TensorIndexSetitem_',
           'HandleBoolTensor_', 'HandleEmptySlice_', 'PreSetitemByTuple_', 'HandleScalarTensorIndex_']
|
|
40
43
|
|
|
41
44
|
|
|
42
45
|
def add_flags(fn=None, **flags):
|
|
@@ -334,7 +337,7 @@ class GradOperation(GradOperation_):
|
|
|
334
337
|
self.get_all = get_all
|
|
335
338
|
self.get_by_list = get_by_list
|
|
336
339
|
self.sens_param = sens_param
|
|
337
|
-
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False)
|
|
340
|
+
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, False, False, False, False, False)
|
|
338
341
|
self.grad_fn = None
|
|
339
342
|
self.fn = None
|
|
340
343
|
self.weights_id = None
|
|
@@ -511,8 +514,8 @@ class _Grad(GradOperation_):
|
|
|
511
514
|
A higher-order function which is used to generate the gradient function by position for the input function.
|
|
512
515
|
"""
|
|
513
516
|
|
|
514
|
-
def __init__(self, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
|
|
515
|
-
return_ids=False):
|
|
517
|
+
def __init__(self, get_all=False, get_by_list=False, sens_param=False, get_by_position=False, has_aux=False,
|
|
518
|
+
get_value=False, return_ids=False, merge_forward=False):
|
|
516
519
|
"""Initialize _Grad."""
|
|
517
520
|
if not isinstance(get_by_position, bool):
|
|
518
521
|
raise TypeError(f"For '_Grad', the 'get_by_position' should be bool, "
|
|
@@ -532,14 +535,16 @@ class _Grad(GradOperation_):
|
|
|
532
535
|
if not isinstance(return_ids, bool):
|
|
533
536
|
raise TypeError(f"For '_Grad', the 'return_ids' should be bool, "
|
|
534
537
|
f"but got {type(return_ids).__name__}")
|
|
538
|
+
self.get_all = get_all
|
|
535
539
|
self.get_by_position = get_by_position
|
|
536
540
|
self.get_by_list = get_by_list
|
|
537
541
|
self.sens_param = sens_param
|
|
538
542
|
self.has_aux = has_aux
|
|
539
543
|
self.get_value = get_value
|
|
540
544
|
self.return_ids = return_ids
|
|
541
|
-
|
|
542
|
-
|
|
545
|
+
self.merge_forward = merge_forward
|
|
546
|
+
GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param, get_by_position, has_aux, get_value,
|
|
547
|
+
return_ids, merge_forward)
|
|
543
548
|
self.grad_fn = None
|
|
544
549
|
self.fn = None
|
|
545
550
|
self.pynative_ = False
|
|
@@ -562,8 +567,8 @@ class _Grad(GradOperation_):
|
|
|
562
567
|
res += (stop_gradient(item),)
|
|
563
568
|
return res
|
|
564
569
|
|
|
565
|
-
grad_ = _Grad(self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
|
|
566
|
-
self.return_ids)
|
|
570
|
+
grad_ = _Grad(self.get_all, self.get_by_list, self.sens_param, self.get_by_position, self.has_aux,
|
|
571
|
+
self.get_value, self.return_ids, self.merge_forward)
|
|
567
572
|
# If calling Grad in GRAPH_MODE or calling Grad in functions decorated with 'jit', do grad in GRAPH_MODE
|
|
568
573
|
# If calling Grad in pure PYNATIVE_MODE do grad in PYNATIVE_MODE
|
|
569
574
|
# In pure PYNATIVE_MODE the out layer after_grad just used to set pynative flag for inner GradOperation.
|
|
@@ -738,6 +743,9 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
738
743
|
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
|
|
739
744
|
|
|
740
745
|
def __call__(self, *args):
|
|
746
|
+
for arg in args:
|
|
747
|
+
if isinstance(arg, np.ndarray):
|
|
748
|
+
raise TypeError("For 'MultitypeFuncGraph', the input can not be numpy.ndarray")
|
|
741
749
|
if len(self.entries) == 1:
|
|
742
750
|
output = self.entries[0][1](*args)
|
|
743
751
|
return output
|
|
@@ -890,7 +898,7 @@ class Map(Map_):
|
|
|
890
898
|
If `ops` is `None`, the first input is the operation, and the other is inputs.
|
|
891
899
|
|
|
892
900
|
Outputs:
|
|
893
|
-
Sequence, the sequence of output after applying the function. e.g. `
|
|
901
|
+
Sequence, the sequence of output after applying the ops function. e.g. `ops(args[0][i], args[1][i])`.
|
|
894
902
|
|
|
895
903
|
Supported Platforms:
|
|
896
904
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -1046,6 +1054,25 @@ class _ListExtend(ListExtend_):
|
|
|
1046
1054
|
_extend = _ListExtend("extend")
|
|
1047
1055
|
|
|
1048
1056
|
|
|
1057
|
+
class _DictSetItem(DictSetItem_):
    """
    A metafuncgraph class that performs item assignment (setitem) on a dict.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _DictSetItem."""
        DictSetItem_.__init__(self, name)

    def __call__(self, *args):
        # The Python-level call is a no-op placeholder that returns None.
        # NOTE(review): presumably the real setitem is provided by the DictSetItem_ backend when
        # this metafuncgraph is compiled -- confirm.
        pass


_dict_setitem = _DictSetItem("setitem")
|
|
1074
|
+
|
|
1075
|
+
|
|
1049
1076
|
class _DictClear(DictClear_):
|
|
1050
1077
|
"""
|
|
1051
1078
|
A metafuncgraph class that clear the dict.
|
|
@@ -13,8 +13,10 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
# ============================================================================
|
|
15
15
|
"""math Operations."""
|
|
16
|
+
import mindspore.ops as ops
|
|
16
17
|
from mindspore.ops import functional as F
|
|
17
18
|
from mindspore.ops.function.math_func import cummin as cummin_
|
|
19
|
+
from mindspore.ops._primitive_cache import _get_cache_prim
|
|
18
20
|
|
|
19
21
|
|
|
20
22
|
def matmul(x1, x2, dtype=None):
|
|
@@ -117,10 +119,9 @@ def mm(input, mat2):
|
|
|
117
119
|
>>> print(out.shape)
|
|
118
120
|
(2, 4)
|
|
119
121
|
"""
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
return matmul(input, mat2)
|
|
122
|
+
_matmul = _get_cache_prim(ops.MatMul)()
|
|
123
|
+
out = _matmul(input, mat2)
|
|
124
|
+
return out
|
|
124
125
|
|
|
125
126
|
|
|
126
127
|
def cummin(x, axis):
|