mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,46 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LogSoftmax"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmax(Expander):
    """Expand LogSoftmax into primitive graph-kernel ops.

    Computes, along ``axis``:
        shifted = x - max(x)
        out = shifted - log(sum(exp(shifted)))
    """

    def _expand(self, graph_builder):
        """Emit the primitive ops for LogSoftmax and return the output node."""
        x = self.inputs[0]
        reduce_axis = self.attrs['axis']
        target = self.processor

        # The reduce ops below expect a tuple of axes.
        reduce_axis = (reduce_axis,) if isinstance(reduce_axis, int) else reduce_axis

        def keepdim_reduce(op_name, operand):
            # A fresh attrs dict is built for every emitted reduce op.
            return graph_builder.emit(op_name, [operand], attrs={'reduce_axis': reduce_axis, 'keep_dims': True})

        src_dtype = x.dtype
        if src_dtype == "float16" or target != "aicore":
            row_max = keepdim_reduce('ReduceMax', x)
        else:
            # On aicore, non-float16 input takes its max in float16 and is cast back.
            # The max is only a stabilizing shift, so its precision does not change the result.
            # NOTE(review): values outside the float16 range presumably become inf after
            # this cast, which could propagate NaN -- confirm inputs are bounded.
            half_x = graph_builder.emit('Cast', [x], attrs={'dst_type': 'float16'})
            half_max = keepdim_reduce('ReduceMax', half_x)
            row_max = graph_builder.emit('Cast', [half_max], attrs={'dst_type': src_dtype})

        shifted = graph_builder.emit('Sub', [x, row_max])
        exp_sum = keepdim_reduce('ReduceSum', graph_builder.emit('Exp', [shifted]))
        log_norm = graph_builder.emit('Log', [exp_sum])
        return graph_builder.emit('Sub', [shifted, log_norm])
|
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LogSoftmaxGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class LogSoftmaxGrad(Expander):
    """Expand LogSoftmaxGrad into primitive graph-kernel ops.

    Computes, along ``axis``:
        out = dy - exp(logits) * sum(dy)

    NOTE(review): ``logits`` is presumably the forward LogSoftmax output, so
    exp(logits) recovers the softmax probabilities -- confirm against the caller.
    """

    def _expand(self, graph_builder):
        """Emit the primitive ops for LogSoftmaxGrad and return the output node."""
        logits, dy = self.inputs
        reduce_axis = self.attrs['axis']
        if not isinstance(reduce_axis, int):
            axes = reduce_axis
        else:
            axes = (reduce_axis,)

        probs = graph_builder.emit('Exp', [logits])
        dy_total = graph_builder.emit('ReduceSum', [dy], attrs={'reduce_axis': axes, 'keep_dims': True})
        # The Mul is emitted before the final Sub, matching the original op order.
        scaled = graph_builder.emit('Mul', [probs, dy_total])
        return graph_builder.emit('Sub', [dy, scaled])
|
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchMatMul and MatMul"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
18
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.check_attrs('transpose_a', 'transpose_b', 'left_format', 'right_format')
class MatMul(Expander):
    """
    Expander for MatMul.

    The expansion only succeeds when the contraction dimension k is 1, the
    processor is 'aicore' and both operands use the DEFAULT format. In that case
    the product degenerates into a broadcast element-wise Mul of [.., m, 1] and
    [.., 1, n]. Otherwise GKException is raised.
    """

    def __init__(self, expand_info):
        super(MatMul, self).__init__(expand_info)
        self.shape_a = self.inputs[0]['shape']
        self.shape_b = self.inputs[1]['shape']
        # Placeholders; the real values are read from attrs in _expand.
        self.transpose_a = False
        self.transpose_b = False
        self.left_format = ""
        self.right_format = ""

    def _optimize_to_mul(self):
        """Return True when this MatMul can be rewritten as a Mul (k == 1 on both operands)."""
        if not (self.processor == 'aicore'
                and self.left_format == DF.DEFAULT
                and self.right_format == DF.DEFAULT):
            return False
        a_k = self.shape_a[-2] if self.transpose_a else self.shape_a[-1]
        b_k = self.shape_b[-1] if self.transpose_b else self.shape_b[-2]
        return a_k == 1 and b_k == 1

    def _check(self):
        """Require at least two inputs (left and right operands)."""
        count = len(self.inputs)
        if count >= 2:
            return
        raise GKException("For 'MatMul', inputs number should bigger than 1, but got {}.".format(count))

    def _expand(self, graph_builder):
        self.transpose_a = self.attrs['transpose_a']
        self.transpose_b = self.attrs['transpose_b']
        self.left_format = self.attrs['left_format']
        self.right_format = self.attrs['right_format']

        def swap_last_two(shape):
            """Return a list copy of `shape` with its last two dimensions exchanged."""
            swapped = list(shape)
            swapped[-2], swapped[-1] = shape[-1], shape[-2]
            return swapped

        if not self._optimize_to_mul():
            raise GKException("MatMul/BatchMatMul do not need to be replaced by Mul")

        # Mul([b m k], [b k n]) with k == 1. Transposing a matrix that has a size-1
        # axis moves no data, so a Reshape to the swapped shape is enough.
        lhs, rhs = self.inputs[0], self.inputs[1]
        if self.transpose_a:
            lhs = graph_builder.emit('Reshape', [lhs], attrs={'shape': swap_last_two(self.shape_a)})
        if self.transpose_b:
            rhs = graph_builder.emit('Reshape', [rhs], attrs={'shape': swap_last_two(self.shape_b)})
        product = graph_builder.emit('Mul', [lhs, rhs])

        needs_cast = 'dst_type' in self.attrs and self.inputs[0].dtype != self.attrs['dst_type']
        if needs_cast:
            product = graph_builder.emit('Cast', [product], attrs={'dst_type': self.attrs['dst_type']})
        return product


class BatchMatMul(MatMul):
    """Expander for BatchMatMul; inherits MatMul's checks and its k == 1 Mul rewrite unchanged."""
|
|
@@ -1,59 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for maximum_grad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .minimum_grad import MinimumGrad
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.check_all_formats_same
class MaximumGrad(Expander):
    """Expander for MaximumGrad.

    Routes the incoming gradient ``dout`` to ``x`` where ``x >= y`` and the
    remainder (``dout - dx``) to ``y``. Each gradient is then summed over the
    axes that broadcasting added, so it matches its own input's shape.
    """

    def _check(self):
        """Reject the configuration where both 'grad_x' and 'grad_y' are False.

        An absent attribute counts as True.

        Raises:
            GKException: if both gradients are disabled.
        """
        grad_x = self.attrs.get('grad_x', True)
        grad_y = self.attrs.get('grad_y', True)
        if not grad_x and not grad_y:
            raise GKException("For 'MaximumGrad', value of attr 'grad_x' and 'grad_y' should not both be False, "
                              "but got {} and {}".format(grad_x, grad_y))
        return super()._check()

    def _expand(self, graph_builder):
        input_x, input_y, input_dout = self.inputs
        ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [ge_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # Compute both axis lists before emitting any reduction, as the original did.
        reduce_axis_x = MinimumGrad.get_reduce_axis(input_x.shape, dx.shape)
        reduce_axis_y = MinimumGrad.get_reduce_axis(input_y.shape, dy.shape)
        dx_out = self._sum_to_shape(graph_builder, dx, reduce_axis_x, input_x.shape)
        dy_out = self._sum_to_shape(graph_builder, dy, reduce_axis_y, input_y.shape)

        # output two results, regardless of grad_x and grad_y
        return dx_out, dy_out

    @staticmethod
    def _sum_to_shape(graph_builder, grad, reduce_axis, target_shape):
        """Sum `grad` over `reduce_axis` (if any), then reshape it to `target_shape` when needed."""
        if not reduce_axis:
            return grad
        reduced = graph_builder.emit('ReduceSum', [grad], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        if reduced.shape != target_shape:
            reduced = graph_builder.emit('Reshape', [reduced], attrs={'shape': target_shape})
        return reduced
|
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for minimum_grad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.check_all_formats_same
class MinimumGrad(Expander):
    """Expander for MinimumGrad.

    Routes the incoming gradient ``dout`` to ``x`` where ``x <= y`` and the
    remainder (``dout - dx``) to ``y``. Each gradient is then summed over the
    axes that broadcasting added, so it matches its own input's shape.
    """

    def _check(self):
        """Reject the configuration where both 'grad_x' and 'grad_y' are False.

        An absent attribute counts as True.

        Raises:
            GKException: if both gradients are disabled.
        """
        grad_x = self.attrs.get('grad_x', True)
        grad_y = self.attrs.get('grad_y', True)
        if not grad_x and not grad_y:
            raise GKException("For 'MinimumGrad', value of attr 'grad_x' and 'grad_y' should not both be False, "
                              "but got {} and {}".format(grad_x, grad_y))
        return super()._check()

    def _expand(self, graph_builder):
        input_x, input_y, input_dout = self.inputs

        le_result = graph_builder.emit('LessEqual', [input_x, input_y])
        le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': input_x.dtype})
        dx = graph_builder.emit('Mul', [le_result, input_dout])
        dy = graph_builder.emit('Sub', [input_dout, dx])

        # Each output shape must equal its input shape. The element-wise ops above
        # may have broadcast the inputs, so the broadcast axes are summed away.
        reduce_axis_x = self.get_reduce_axis(input_x.shape, dx.shape)
        reduce_axis_y = self.get_reduce_axis(input_y.shape, dy.shape)
        dx_out = self._sum_to_shape(graph_builder, dx, reduce_axis_x, input_x.shape)
        dy_out = self._sum_to_shape(graph_builder, dy, reduce_axis_y, input_y.shape)

        # output two results, regardless of grad_x and grad_y
        return dx_out, dy_out

    @staticmethod
    def _sum_to_shape(graph_builder, grad, reduce_axis, target_shape):
        """Sum `grad` over `reduce_axis` (if any), then reshape it to `target_shape` when needed."""
        if not reduce_axis:
            return grad
        reduced = graph_builder.emit('ReduceSum', [grad], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        if reduced.shape != target_shape:
            reduced = graph_builder.emit('Reshape', [reduced], attrs={'shape': target_shape})
        return reduced

    @staticmethod
    def get_reduce_axis(original_shape, broadcast_shape):
        """Compute the axes of `broadcast_shape` that were broadcast from `original_shape`.

        `original_shape` is left-padded with 1s to the rank of `broadcast_shape`.
        Every axis where the padded size is 1 but the broadcast size differs is
        returned. Either shape may be a list or a tuple.

        Returns:
            list[int]: the axes to reduce, in ascending order (possibly empty).

        Raises:
            ValueError: if `original_shape` has a higher rank than `broadcast_shape`,
                or if some axis differs and the original size there is not 1.
        """
        if len(original_shape) > len(broadcast_shape):
            raise ValueError("For 'MinimumGrad', the length of original_shape should be less than or equal to the "
                             "length of broadcast_shape, but got {} and {}".format(original_shape, broadcast_shape))

        # list(...) so a tuple shape can be concatenated with the padding list.
        padded = [1] * (len(broadcast_shape) - len(original_shape)) + list(original_shape)
        reduce_axis = []
        for idx, (orig_dim, bcast_dim) in enumerate(zip(padded, broadcast_shape)):
            if orig_dim == bcast_dim:
                continue
            if orig_dim != 1:
                raise ValueError("For 'MinimumGrad', original_shape {} and broadcast_shape {} can not broadcast."
                                 .format(original_shape, broadcast_shape))
            reduce_axis.append(idx)
        return reduce_axis
|
|
@@ -1,26 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for OnesLike"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class OnesLike(Expander):
    """Expander for OnesLike: broadcasts a scalar 1 of the input's dtype to the input's shape."""

    def _expand(self, graph_builder):
        template = self.inputs[0]
        one = graph_builder.value(template.dtype, 1)
        return graph_builder.emit('BroadcastTo', [one], attrs={'shape': template.shape})
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for reduce_mean"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT)
@VLD.check_attrs('axis', 'keep_dims')
class ReduceMean(Expander):
    """Expander for ReduceMean: ReduceSum over `axis`, then divide by the number of reduced elements."""

    def _expand(self, graph_builder):
        x = self.inputs[0]
        axes = self.attrs['axis']
        keep_dims = self.attrs['keep_dims']

        if isinstance(axes, (tuple, list)):
            if not axes:
                # An empty axis list means "reduce over every dimension".
                axes = list(range(len(x.shape)))
        else:
            axes = (axes,)

        # Element count of the reduced region, accumulated as a float.
        count = 1.0
        for dim in axes:
            count *= x.shape[dim]
        divisor = graph_builder.value(x.dtype, count)

        total = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axes, 'keep_dims': keep_dims})
        return graph_builder.emit('RealDiv', [total, divisor])
|
|
@@ -1,32 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for relu_grad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class ReluGrad(Expander):
    """Expander for ReluGrad.

    Multiplies input 0 by a 0/1 mask of ``input 1 > 0``; the mask is cast to input 0's dtype.
    Presumably input 0 is the upstream gradient and input 1 the ReLU reference tensor.
    """

    def _expand(self, graph_builder):
        grad, ref = self.inputs[0], self.inputs[1]
        zero = graph_builder.value(ref.dtype, 0)
        mask = graph_builder.emit('Greater', [ref, zero])
        mask = graph_builder.emit('Cast', [mask], attrs={'dst_type': grad.dtype})
        return graph_builder.emit('Mul', [mask, grad])
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for SigmoidCrossEntropyWithLogits"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogits(Expander):
    """Expander for SigmoidCrossEntropyWithLogits (numerically stable element-wise loss)."""

    def _expand(self, graph_builder):
        logits, labels = self.inputs
        # loss = -(labels * log(sigmoid(logits)) + (1 - labels) * log(1 - sigmoid(logits)))
        # is evaluated in the overflow-safe equivalent form
        #   max(logits, 0) - logits * labels + log(1 + exp(-|logits|))
        one = graph_builder.value(logits.dtype, 1.0)
        zero = graph_builder.value(logits.dtype, 0.0)
        clipped = graph_builder.emit('Maximum', [logits, zero])
        cross_term = graph_builder.emit('Mul', [logits, labels])
        neg_abs = graph_builder.emit('Neg', [graph_builder.emit('Abs', [logits])])
        log_term = graph_builder.emit('Log', [graph_builder.emit('Add', [one, graph_builder.emit('Exp', [neg_abs])])])
        diff = graph_builder.emit('Sub', [clipped, cross_term])
        return graph_builder.emit('Add', [diff, log_term])
|
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for SigmoidCrossEntropyWithLogitsGrad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class SigmoidCrossEntropyWithLogitsGrad(Expander):
    """Expander for SigmoidCrossEntropyWithLogitsGrad: ``(sigmoid(logits) - label) * dout``."""

    def _expand(self, graph_builder):
        logits, label, dout = self.inputs
        # sigmoid(z) is built as 1 / (1 + exp(-z)).
        one = graph_builder.value(logits.dtype, 1.0)
        denom = graph_builder.emit('Add', [one, graph_builder.emit('Exp', [graph_builder.emit('Neg', [logits])])])
        sigmoid = graph_builder.emit('RealDiv', [one, denom])
        residual = graph_builder.emit('Sub', [sigmoid, label])
        return graph_builder.emit('Mul', [residual, dout])
|
|
@@ -1,31 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for SigmoidGrad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class SigmoidGrad(Expander):
    """Expander for SigmoidGrad: ``(1 - y) * y * dy``."""

    def _expand(self, graph_builder):
        y, dy = self.inputs
        one = graph_builder.value(y.dtype, 1.0)
        complement = graph_builder.emit('Sub', [one, y])
        weighted = graph_builder.emit('Mul', [y, dy])
        return graph_builder.emit('Mul', [complement, weighted])
|
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for slice"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_attrs('begin', 'size')
class Slice(Expander):
    """Expander for Slice: lowered to a unit-stride StridedSlice with ``end = begin + size``."""

    def _expand(self, graph_builder):
        src = self.inputs[0]
        begin = self.attrs['begin']
        size = self.attrs['size']
        end = [start + size[i] for i, start in enumerate(begin)]
        strides = [1] * len(begin)
        # The output tensor has exactly `size` as its shape.
        output = graph_builder.tensor(size, src.dtype, src.data_format)
        graph_builder.op('StridedSlice', output, [src], attrs={'begin': begin, 'end': end, 'strides': strides})
        return output
|
|
@@ -1,42 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for SoftmaxCrossEntropyWithLogits"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class SoftmaxCrossEntropyWithLogits(Expander):
    """Expander for SoftmaxCrossEntropyWithLogits.

    Returns:
        (loss, dlogits), where, along the last axis,
        ``loss = -ReduceSum(label * log(max(softmax(logits), 1e-6)))`` and
        ``dlogits = softmax(logits) - label``.
    """

    def _expand(self, graph_builder):
        logits, label = self.inputs
        axis = (-1,)

        # Subtracting the row max before Exp keeps the softmax from overflowing.
        max_x = graph_builder.emit('ReduceMax', [logits], attrs={'reduce_axis': axis, 'keep_dims': True})
        data_sub = graph_builder.emit('Sub', [logits, max_x])
        data_exp = graph_builder.emit('Exp', [data_sub])
        data_expsum = graph_builder.emit('ReduceSum', [data_exp], attrs={'reduce_axis': axis, 'keep_dims': True})
        data_softmax = graph_builder.emit('RealDiv', [data_exp, data_expsum])

        # Clamp the softmax away from zero so Log stays finite.
        const_eps = graph_builder.value(logits.dtype, 0.000001)
        data_softmax_safety = graph_builder.emit("Maximum", [data_softmax, const_eps])
        softmax_log = graph_builder.emit('Log', [data_softmax_safety])
        label_mul_log = graph_builder.emit('Mul', [label, softmax_log])
        # Previously `tmp_res = data_expsum = ...`; the stray chained assignment silently
        # rebound data_expsum and has been removed.
        label_log_sum = graph_builder.emit('ReduceSum', [label_mul_log],
                                           attrs={'reduce_axis': axis, 'keep_dims': False})
        loss = graph_builder.emit('Neg', [label_log_sum])
        dlogits = graph_builder.emit('Sub', [data_softmax, label])
        return loss, dlogits
|
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for SoftmaxGradExt"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from ._utils import get_reduce_axis_shape
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.FRAC_NZ, DF.FRAC_NZ, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('axis')
class SoftmaxGradExt(Expander):
    """Expander for SoftmaxGradExt: ``(x - ReduceSum(x * y, axis, keep_dims=True)) * y * z``."""

    def _expand(self, graph_builder):
        x, y, z = self.inputs
        reduce_axis, ori_reduced_shape = get_reduce_axis_shape(x.shape, x.data_format, self.attrs['axis'])

        prod = graph_builder.emit('Mul', [x, y])
        prod_sum = graph_builder.emit('ReduceSum', [prod],
                                      attrs={'reduce_axis': reduce_axis, 'keep_dims': True,
                                             'reduce_output_fuse': True})
        if x.data_format == DF.FRAC_NZ:
            # For FRAC_NZ, reshape the keep-dims sum to the shape given by get_reduce_axis_shape.
            prod_sum = graph_builder.emit('Reshape', [prod_sum], attrs={'shape': ori_reduced_shape})
        centered = graph_builder.emit('Sub', [x, prod_sum])
        scaled = graph_builder.emit('Mul', [centered, y])
        return graph_builder.emit('Mul', [scaled, z])
|