mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for DropoutGrad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class DropoutGrad(Expander):
    """Expander for DropoutGrad.

    Emits ``dy * (1.0 / keep_prob) * mask`` as two chained 'Mul' nodes. The
    reciprocal of the ``keep_prob`` attribute is computed in Python and
    materialised as a graph constant of ``dy``'s dtype. A ``keep_prob`` of 0
    therefore fails in that Python-side division.
    """

    def _expand(self, graph_builder):
        dy, mask = self.inputs
        inv_keep_prob = graph_builder.value(dy.dtype, 1.0 / self.attrs['keep_prob'])
        scaled_dy = graph_builder.emit('Mul', [dy, inv_keep_prob])
        return graph_builder.emit('Mul', [scaled_dy, mask])
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for equal_count"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.check_all_formats_same
class EqualCount(Expander):
    """Expander for EqualCount.

    It emits 'Equal' on the two inputs, casts the result to float32 and
    'ReduceSum's it over every axis with ``keep_dims=False``. Finally, it casts
    the sum back to the first input's dtype when the two dtypes differ.
    """

    def __init__(self, expand_info):
        super().__init__(expand_info)
        # NOTE(review): here the inputs are read as dicts ('shape', 'data_type'),
        # while _expand reads `.shape` / `.dtype` attributes. Presumably the base
        # Expander swaps self.inputs for graph tensors before _expand runs.
        # Confirm against the base class.
        first, second = self.inputs[0], self.inputs[1]
        self.shape_x, self.shape_y = first['shape'], second['shape']
        self.dtype_x, self.dtype_y = first['data_type'], second['data_type']

    def _check(self):
        """Raise GKException unless both inputs share the same shape and data type."""
        if self.shape_x != self.shape_y:
            message = "For 'EqualCount', the inputs shape should be same, but got {} and {}".format(
                self.shape_x, self.shape_y)
            raise GKException(message)
        if self.dtype_x != self.dtype_y:
            message = "For 'EqualCount', the inputs data type should be same, but got {} and {}".format(
                self.dtype_x, self.dtype_y)
            raise GKException(message)

    def _expand(self, graph_builder):
        lhs, rhs = self.inputs[0], self.inputs[1]

        equal_mask = graph_builder.emit('Equal', [lhs, rhs])
        mask_as_f32 = graph_builder.emit('Cast', [equal_mask], attrs={'dst_type': 'float32'})
        every_axis = list(range(len(lhs.shape)))
        match_count = graph_builder.emit('ReduceSum', [mask_as_f32],
                                         attrs={'reduce_axis': every_axis, 'keep_dims': False})

        # The sum is produced in float32; convert only when the input dtype differs.
        if match_count.dtype == lhs.dtype:
            return match_count
        return graph_builder.emit('Cast', [match_count], attrs={'dst_type': lhs.dtype})
|
|
@@ -1,35 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for erfc"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class Erfc(Expander):
    """Erfc expander: emits ``1 - Erf(x)``.

    float16 inputs are cast to float32 first, and the result is cast back to
    float16. Other dtypes are computed directly in their own dtype.
    """

    def _expand(self, graph_builder):
        x = self.inputs[0]
        upcast = x.dtype == "float16"
        # The constant is created before any cast, in the dtype used for the math.
        one = graph_builder.value("float32" if upcast else x.dtype, 1)
        if upcast:
            x = graph_builder.emit('Cast', [x], attrs={'dst_type': "float32"})
        erf_x = graph_builder.emit('Erf', [x])
        erfc_x = graph_builder.emit('Sub', [one, erf_x])
        if upcast:
            erfc_x = graph_builder.emit('Cast', [erfc_x], attrs={'dst_type': "float16"})
        return erfc_x
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for expand_dims"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_attrs('axis')
class ExpandDims(Expander):
    """ExpandDims expander: lowered to a Reshape that inserts size-1 axes."""

    def _expand(self, graph_builder):
        x = self.inputs[0]
        target_shape = self.infer_shape(x.shape, self.attrs['axis'])
        return graph_builder.emit('Reshape', [x], attrs={'shape': target_shape})

    @staticmethod
    def infer_shape(shape, axis):
        """Return a copy of ``shape`` with a 1 inserted at each position in ``axis``.

        ``axis`` may be an int or a list/tuple of ints. Positions are applied one
        after another, so each one is interpreted against the already-extended
        shape. A negative position ``p`` inserts at ``p + rank + 1``.

        Raises:
            ValueError: if a position is not an int or lies outside
                ``[-rank - 1, rank]``, or if ``axis`` is not an int, list or tuple.
        """
        def _insert_unit_dim(dims, pos):
            rank = len(dims)
            if not isinstance(pos, int) or not -rank - 1 <= pos <= rank:
                raise ValueError("For 'ExpandDims', value of attr 'axis' should be of type int and in the range [{}, "
                                 "{}], but got {} with type {}".format(-rank - 1, rank, pos, type(pos)))
            dims.insert(pos if pos >= 0 else pos + rank + 1, 1)
            return dims

        result = shape[:]
        if isinstance(axis, int):
            return _insert_unit_dim(result, axis)
        if not isinstance(axis, (list, tuple)):
            raise ValueError("For 'ExpandDims', type of attr 'axis' should be one of ['int', 'list', 'tuple'], but got {} "
                             "with type {}".format(axis, type(axis)))
        for pos in axis:
            result = _insert_unit_dim(result, pos)
        return result
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for fused_adam"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class FusedAdam(Expander):
    """FusedAdam expander.

    Inputs: (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v, gradient).
    Emits:
        m'     = beta_1 * m + one_sub_beta_1 * gradient
        v'     = beta_2 * v + one_sub_beta_2 * gradient * gradient
        param' = param - lr * m' / (sqrt(v') + eps)
    then writes param', m' and v' back with a chain of InplaceAssign ops and
    returns the last one.
    """

    def _expand(self, graph_builder):
        (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
         eps, lr, param, m, v, gradient) = self.inputs
        emit = graph_builder.emit

        decayed_m = emit('Mul', [beta_1, m])
        grad_part = emit('Mul', [one_sub_beta_1, gradient])
        next_m = emit('Add', [decayed_m, grad_part])

        decayed_v = emit('Mul', [beta_2, v])
        grad_sq = emit('Mul', [gradient, gradient])
        grad_sq_part = emit('Mul', [one_sub_beta_2, grad_sq])
        next_v = emit('Add', [decayed_v, grad_sq_part])

        denom = emit('Add', [emit('Sqrt', [next_v]), eps])
        step = emit('Mul', [lr, emit('RealDiv', [next_m, denom])])
        next_param = emit('Sub', [param, step])

        # Each InplaceAssign takes the previous result as its third input, so the
        # writes of param, m and v are chained; the first one uses next_param itself.
        out = next_param
        for target, new_value in ((param, next_param), (m, next_m), (v, next_v)):
            out = emit('InplaceAssign', [target, new_value, out], attrs={'fake_output': True})
        return out
|
|
@@ -1,47 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for fused_adam_weight_decay"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class FusedAdamWeightDecay(Expander):
    """FusedAdamWeightDecay expander.

    Inputs: (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2, eps, lr, param, m, v,
    gradient, weight_decay). Emits:
        m'     = beta_1 * m + one_sub_beta_1 * gradient
        v'     = beta_2 * v + one_sub_beta_2 * gradient * gradient
        param' = param - lr * (m' / (sqrt(v') + eps) + weight_decay * param)
    then writes param', m' and v' back with a chain of InplaceAssign ops and
    returns the last one.
    """

    def _expand(self, graph_builder):
        (beta_1, one_sub_beta_1, beta_2, one_sub_beta_2,
         eps, lr, param, m, v, gradient, weight_decay) = self.inputs
        emit = graph_builder.emit

        decayed_m = emit('Mul', [beta_1, m])
        grad_part = emit('Mul', [one_sub_beta_1, gradient])
        next_m = emit('Add', [decayed_m, grad_part])

        decayed_v = emit('Mul', [beta_2, v])
        grad_sq = emit('Mul', [gradient, gradient])
        grad_sq_part = emit('Mul', [one_sub_beta_2, grad_sq])
        next_v = emit('Add', [decayed_v, grad_sq_part])

        denom = emit('Add', [emit('Sqrt', [next_v]), eps])
        adam_step = emit('RealDiv', [next_m, denom])
        decay_term = emit('Mul', [weight_decay, param])
        total_step = emit('Add', [adam_step, decay_term])
        next_param = emit('Sub', [param, emit('Mul', [lr, total_step])])

        # Each InplaceAssign takes the previous result as its third input, so the
        # writes of param, m and v are chained; the first one uses next_param itself.
        out = next_param
        for target, new_value in ((param, next_param), (m, next_m), (v, next_v)):
            out = emit('InplaceAssign', [target, new_value, out], attrs={'fake_output': True})
        return out
|
|
@@ -1,28 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for FusedMulAdd"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class FusedMulAdd(Expander):
    """FusedMulAdd expander: computes ``x * y + z`` as a Mul followed by an Add."""

    def _expand(self, graph_builder):
        lhs, rhs, addend = self.inputs
        product = graph_builder.emit('Mul', [lhs, rhs])
        return graph_builder.emit('Add', [product, addend])
|
|
@@ -1,70 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for gelugrad"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class GeLUGrad(Expander):
    """GeLUGrad expander.

    Expands the gradient of the tanh approximation of GeLU. With
        tanh_para = sqrt(2 / pi) * (x + 0.044715 * x^3)
        mul_right = sqrt(2 / pi) * (1 + 3 * 0.044715 * x^2)   # d(tanh_para)/dx
    the emitted result is
        dy * (0.5 * (1 + tanh(tanh_para))
              + 0.5 * x * (1 - tanh(tanh_para)^2) * mul_right)
    Inputs are (dy, x, <third input, unused here>).
    """
    CSVALUE = 0.044715
    CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
    # x^2 coefficient of d(tanh_para)/dx, i.e. 3 * CSVALUE. Derived instead of
    # hand-typed: the former literal 0.134141 did not equal 3 * 0.044715 = 0.134145.
    CSVALUE_TRI = CSVALUE * 3

    def _expand(self, graph_builder):
        input_dy, input_x, _ = self.inputs

        # create some const var
        const_csvalue = graph_builder.value(input_dy.dtype, self.CSVALUE)
        const_csvalue_sqrt_two_div_pi = graph_builder.value(input_dy.dtype, self.CSVALUE_SQRT_TWO_DIV_PI)
        const_csvalue_tri = graph_builder.value(input_dy.dtype, self.CSVALUE_TRI)
        const_one = graph_builder.value(input_dy.dtype, 1)
        const_half = graph_builder.value(input_dy.dtype, 0.5)

        # mul_right = sqrt(2 / pi) * (1 + 3 * 0.044715 * x * x)
        mul_double = graph_builder.emit('Mul', [input_x, input_x])
        mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
        mul_add_one = graph_builder.emit('Add', [const_one, mul_double_mul_tri])
        mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])

        # tanh_para = sqrt(2 / pi) * (x + 0.044715 * x * x * x)
        mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
        mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
        mul_add_x = graph_builder.emit('Add', [input_x, mul_triple_mul_csvalue])
        tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])

        # 0.5 * (1.0 + tanh(tanh_para))
        tanh_res = graph_builder.emit('Tanh', [tanh_para])
        tanh_res_add_one = graph_builder.emit('Add', [const_one, tanh_res])
        half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])

        # 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
        tan_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
        one_sub_tan_res_double = graph_builder.emit('Sub', [const_one, tan_res_double])
        half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
        mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tan_res_double])
        mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])

        # dy * y'
        result_tmp = graph_builder.emit('Add', [half_mul_tanh_res_add_one, mul_final])
        result = graph_builder.emit('Mul', [input_dy, result_tmp])

        return result
|
|
@@ -1,40 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for GkDropout"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
@VLD.check_attrs('keep_prob')
class GkDropout(Expander):
    """GkDropout expander.

    Inputs are (x, mask). An element is kept where ``mask <= keep_prob``, and
    kept values are scaled by ``1 / keep_prob``. Returns ``(output, keep_mask)``,
    with the keep mask cast to the dtype of ``x``.
    """

    def _expand(self, graph_builder):
        x, random_mask = self.inputs
        prob = self.attrs['keep_prob']

        scale = graph_builder.value(x.dtype, 1.0 / prob)
        threshold = graph_builder.value(x.dtype, prob)

        if random_mask.dtype != x.dtype:
            random_mask = graph_builder.emit('Cast', [random_mask], attrs={'dst_type': x.dtype})
        keep = graph_builder.emit('LessEqual', [random_mask, threshold])  # boolean result
        keep = graph_builder.emit('Cast', [keep], attrs={'dst_type': x.dtype})

        scaled = graph_builder.emit('Mul', [scale, x])
        return graph_builder.emit('Mul', [scaled, keep]), keep
|
|
@@ -1,25 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for Identity"""
|
|
16
|
-
from ._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class Identity(Expander):
    """Identity expander: a Reshape of the single input onto its own shape."""

    def _expand(self, graph_builder):
        source = self.inputs[0]
        return graph_builder.emit('Reshape', [source], attrs={'shape': source.shape})
|
|
@@ -1,93 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LayerNorm"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from ._utils import infer_shape_from_fractalnz, get_reduced_ori_shape, to_frac_z_axis
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.FRAC_NZ, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis', 'epsilon')
class LayerNorm(Expander):
    """LayerNorm expander.

    Inputs are (x, gamma, beta). Normalizes x over every axis from
    ``begin_norm_axis`` onward, then scales by gamma and shifts by beta.
    Returns ``(y, mean, variance)``. On the 'aicore' processor, float16 inputs
    are computed in float32 and all three outputs are cast back to float16.
    """

    def _expand(self, graph_builder):
        x, gamma, beta = self.inputs
        processor = self.processor
        norm_start = self.attrs['begin_norm_axis']
        epsilon = self.attrs['epsilon']

        upcast = processor == 'aicore' and x.dtype == 'float16'
        if upcast:
            x, gamma, beta = [graph_builder.emit('Cast', [t], attrs={'dst_type': 'float32'})
                              for t in (x, gamma, beta)]

        is_nz = x.data_format == DF.FRAC_NZ
        # For FRAC_NZ inputs, axis bookkeeping is done on the shape recovered by
        # infer_shape_from_fractalnz rather than the physical tensor shape.
        logical_shape = infer_shape_from_fractalnz(x.shape) if is_nz else x.shape

        if norm_start < 0:
            norm_start += len(logical_shape)
        reduce_axis = tuple(i for i, _ in enumerate(logical_shape) if i >= norm_start)

        reduce_count = 1.0
        for i in reduce_axis:
            reduce_count *= logical_shape[i]
        reduced_shape = get_reduced_ori_shape(logical_shape, reduce_axis)
        axis = to_frac_z_axis(logical_shape, reduce_axis) if is_nz else reduce_axis

        mean_cof = graph_builder.value(x.dtype, 1.0 / reduce_count)

        def _restore_shape(tensor):
            # FRAC_NZ reductions are reshaped to the reduced logical shape.
            if is_nz:
                return graph_builder.emit('Reshape', [tensor], attrs={'shape': reduced_shape})
            return tensor

        # mean = sum(x) / N
        mean_sum = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': axis, 'keep_dims': True})
        mean = _restore_shape(graph_builder.emit('Mul', [mean_sum, mean_cof]))

        # variance = sum((x - mean)^2) / N
        centered = graph_builder.emit('Sub', [x, mean])
        squared = graph_builder.emit('Mul', [centered, centered])
        var_sum = graph_builder.emit('ReduceSum', [squared], attrs={'reduce_axis': axis, 'keep_dims': True})
        variance = _restore_shape(graph_builder.emit('Mul', [var_sum, mean_cof]))

        # y = (x - mean) * rsqrt(variance + epsilon) * gamma + beta
        # (x - mean) is emitted again instead of reusing `centered`, which keeps
        # the emitted graph node-for-node unchanged.
        x_minus_mean = graph_builder.emit('Sub', [x, mean])
        eps = graph_builder.value(x.dtype, epsilon)
        inv_std = graph_builder.emit('Rsqrt', [graph_builder.emit('Add', [variance, eps])])
        normalized = graph_builder.emit('Mul', [x_minus_mean, inv_std])
        scaled = graph_builder.emit('Mul', [normalized, gamma])
        res = graph_builder.emit('Add', [scaled, beta])

        if upcast:
            res, mean, variance = [graph_builder.emit('Cast', [t], attrs={'dst_type': 'float16'})
                                   for t in (res, mean, variance)]
        return res, mean, variance
|
|
@@ -1,113 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for LayerNormGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('begin_norm_axis', 'begin_params_axis')
class LayerNormGrad(Expander):
    """LayerNormGrad expander.

    Inputs are (x, dy, variance, mean, gamma); outputs are (dx, dg, db).
    dg and db are reduced over axes ``[0, begin_params_axis)``, and the dx terms
    reduce over ``[begin_norm_axis, rank)``. ``epsilon`` defaults to 1e-12 when
    the attribute is absent. On the 'aicore' processor, float16 inputs are
    computed in float32 and the outputs are cast back to float16.
    """

    def _expand(self, graph_builder):
        x, dy, variance, mean, gamma = self.inputs
        processor = self.processor
        norm_start = self.attrs['begin_norm_axis']
        params_start = self.attrs['begin_params_axis']
        epsilon = self.attrs['epsilon'] if 'epsilon' in self.attrs else 1e-12

        upcast = processor == 'aicore' and x.dtype == 'float16'
        if upcast:
            x, dy, variance, mean, gamma = [
                graph_builder.emit('Cast', [t], attrs={'dst_type': 'float32'})
                for t in (x, dy, variance, mean, gamma)]

        rank = len(x.shape)
        if norm_start < 0:
            norm_start += rank
        if params_start < 0:
            params_start += rank
        norm_axis = tuple(range(norm_start, rank))
        param_axis = tuple(range(0, params_start))

        reduce_size = 1.0
        for i in norm_axis:
            reduce_size *= x.shape[i]

        dtype = x.dtype
        eps = graph_builder.value(dtype, epsilon)
        neg_half = graph_builder.value(dtype, -0.5)
        neg_two = graph_builder.value(dtype, -2.0)
        two = graph_builder.value(dtype, 2.0)
        neg_one = graph_builder.value(dtype, -1.0)
        mean_cof = graph_builder.value(dtype, (1.0 / reduce_size))

        def _sum_norm(tensor):
            # Sum over the normalized axes, keeping dims for broadcasting.
            return graph_builder.emit('ReduceSum', [tensor], attrs={'reduce_axis': norm_axis, 'keep_dims': True})

        def _sum_param(tensor):
            # Sum over the leading (non-parameter) axes, dropping them.
            return graph_builder.emit('ReduceSum', [tensor], attrs={'reduce_axis': param_axis, 'keep_dims': False})

        # rstd = (variance + eps) ** -0.5, emitted as exp(-0.5 * log(variance + eps))
        log_var = graph_builder.emit('Log', [graph_builder.emit('Add', [variance, eps])])
        rstd = graph_builder.emit('Exp', [graph_builder.emit('Mul', [log_var, neg_half])])

        x_centered = graph_builder.emit('Sub', [x, mean])

        # dg = sum(dy * rstd * (x - mean)), db = sum(dy)
        x_hat = graph_builder.emit('Mul', [rstd, x_centered])
        dg = _sum_param(graph_builder.emit('Mul', [dy, x_hat]))
        db = _sum_param(dy)

        # pd_var = -0.5 * rstd^3 * sum(dy * gamma * (x - mean))
        rstd_sq = graph_builder.emit('Mul', [rstd, rstd])
        rstd_cube = graph_builder.emit('Mul', [rstd, rstd_sq])
        dy_gamma = graph_builder.emit('Mul', [dy, gamma])
        var_term = _sum_norm(graph_builder.emit('Mul', [dy_gamma, x_centered]))
        pd_var = graph_builder.emit('Mul', [graph_builder.emit('Mul', [var_term, rstd_cube]), neg_half])

        # pd_mean = -rstd * sum(dy * gamma) + sum(-2 * (x - mean)) / N * pd_var
        dy_gamma_sum = _sum_norm(dy_gamma)
        neg_rstd = graph_builder.emit('Mul', [rstd, neg_one])
        pd_mean_1 = graph_builder.emit('Mul', [neg_rstd, dy_gamma_sum])
        neg_two_centered_sum = _sum_norm(graph_builder.emit('Mul', [neg_two, x_centered]))
        pd_mean_2 = graph_builder.emit('Mul', [graph_builder.emit('Mul', [neg_two_centered_sum, mean_cof]), pd_var])
        pd_mean = graph_builder.emit('Add', [pd_mean_1, pd_mean_2])

        # dx = dy * gamma * rstd + pd_var * (x - mean) * 2 / N + pd_mean / N
        dx_1 = graph_builder.emit('Mul', [dy_gamma, rstd])
        var_grad = graph_builder.emit('Mul', [pd_var, x_centered])
        var_grad = graph_builder.emit('Mul', [var_grad, two])
        dx_2 = graph_builder.emit('Mul', [var_grad, mean_cof])
        dx_3 = graph_builder.emit('Mul', [pd_mean, mean_cof])
        dx = graph_builder.emit('Add', [graph_builder.emit('Add', [dx_1, dx_2]), dx_3])

        if upcast:
            dx, dg, db = [graph_builder.emit('Cast', [t], attrs={'dst_type': 'float16'})
                          for t in (dx, dg, db)]
        return dx, dg, db
|