mindspore 2.1.0__cp39-cp39-win_amd64.whl → 2.2.10__cp39-cp39-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore has been flagged as potentially problematic; see the package's registry page for details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +12 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +61 -71
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +74 -104
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +13 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +28 -5
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +141 -88
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +84 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +165 -109
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +144 -86
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +172 -187
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +51 -57
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +6 -17
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/aicpu/eps.py +32 -0
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +28 -19
- mindspore/ops/operations/_grad_ops.py +72 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +189 -141
- mindspore/ops/operations/nn_ops.py +794 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +36 -15
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/METADATA +5 -3
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/RECORD +433 -479
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.10.dist-info}/top_level.txt +0 -0
|
@@ -1,80 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""generate json desc for graph kernel ops"""
|
|
16
|
-
import json
|
|
17
|
-
import json.decoder as jd
|
|
18
|
-
import traceback
|
|
19
|
-
from mindspore import log as logger
|
|
20
|
-
import mindspore._extends.graph_kernel.expanders as expanders
|
|
21
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def create_expander(expand_info):
    """Create an expander according to op name.

    Args:
        expand_info (dict): Expand info as produced by ``extract_expand_info``. Its
            ``'name'`` entry selects the expander, and the whole dict is passed to it.

    Returns:
        The object obtained by calling the matching attribute of the ``expanders``
        package with ``expand_info``.

    Raises:
        GraphKernelUnsupportedException: If the ``expanders`` package provides no
            expander for the op name.
    """
    op_name = str(expand_info['name'])
    # Only capitalized, callable attributes are expanders. This is the same naming
    # rule get_expander_op_list uses to advertise supported ops. A bare hasattr()
    # check would also accept submodules (e.g. "addn") or dunders (e.g. "__name__"),
    # which then fail with an uncaught TypeError when called.
    expander = getattr(expanders, op_name, None)
    if not op_name[:1].isupper() or not callable(expander):
        raise GraphKernelUnsupportedException("Expander does not support op: {}".format(op_name))
    return expander(expand_info)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def extract_expand_info(kernel_info):
    """Reshape a decoded kernel json description into the expander's input format.

    Args:
        kernel_info (dict): Decoded kernel description. It must have ``'name'``,
            ``'output_desc'`` and ``'process'``. It may have ``'input_desc'``, a list of
            lists to be flattened, and ``'attr'``, a list of ``{"name": ..., "value": ...}``
            entries.

    Returns:
        dict: Has keys ``name``, ``input_desc`` (flattened list), ``output_desc``,
        ``attr`` (name -> value dict) and ``process``. A missing or empty
        ``input_desc``/``attr`` yields an empty list/dict.

    Raises:
        KeyError: If a required key is missing from ``kernel_info``.
    """
    raw_inputs = kernel_info['input_desc'] if 'input_desc' in kernel_info else None
    # Concatenate the per-input description lists into one flat list.
    flat_inputs = [item for group in (raw_inputs or []) for item in group]

    raw_attrs = kernel_info['attr'] if 'attr' in kernel_info else None
    # Later entries with a repeated name overwrite earlier ones.
    attr_map = {entry["name"]: entry["value"] for entry in (raw_attrs or [])}

    return {
        "name": kernel_info["name"],
        "input_desc": flat_inputs,
        "output_desc": kernel_info["output_desc"],
        "attr": attr_map,
        "process": kernel_info["process"],
    }
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def get_op_expander(json_str: str):
    """Expand the graph kernel op described by a json string.

    Args:
        json_str (str): Json description of the kernel. It is decoded with
            ``json.loads`` and reshaped by ``extract_expand_info``.

    Returns:
        str: The expanded graph's ``dump()`` serialized with ``json.dumps``. Returns an
        empty string when the input is not valid json or no expander supports the op.
    """
    try:
        expand_info = extract_expand_info(json.loads(json_str))
        graph = create_expander(expand_info).run()
        # Serialize the expanded graph back into a json description.
        return json.dumps(graph.dump())
    except jd.JSONDecodeError:
        logger.error(f"Decode input json str failed in expander, json is: {json_str}")
        logger.error(traceback.format_exc())
        return ""
    except GraphKernelUnsupportedException as unsupported:
        # Logged at info level (not error) and reported as an empty string.
        logger.info(unsupported.message)
        return ""
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def get_expander_op_list():
    """Return the supported expander op names as one space-separated string.

    An attribute of the ``expanders`` package counts as a supported op when its name
    starts with an uppercase letter. Names appear in ``dir()`` order.
    """
    supported = (attr for attr in dir(expanders) if attr[0].isupper())
    return ' '.join(supported)
|
|
@@ -1,54 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ============================================================================
|
|
15
|
-
"""expanders init. Deprecated, please add the new operators in the c++ file"""
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
from .addn import AddN
|
|
19
|
-
from .batchnorm import BatchNorm
|
|
20
|
-
from .batchnorm_grad import BatchNormGrad
|
|
21
|
-
from .clip_by_norm_no_div_sum import ClipByNormNoDivSum
|
|
22
|
-
from .conv2d import Conv2D
|
|
23
|
-
from .complex import CAbs, CAdd, CDiv, CMul, CSub, CRealDiv
|
|
24
|
-
from .dropout_grad import DropoutGrad
|
|
25
|
-
from .equal_count import EqualCount
|
|
26
|
-
from .erfc import Erfc
|
|
27
|
-
from .fused_adam import FusedAdam
|
|
28
|
-
from .fused_adam_weight_decay import FusedAdamWeightDecay
|
|
29
|
-
from .fused_mul_add import FusedMulAdd
|
|
30
|
-
from .gelu_grad import GeLUGrad
|
|
31
|
-
from .gkdropout import GkDropout
|
|
32
|
-
from .identity import Identity
|
|
33
|
-
from .layernorm import LayerNorm
|
|
34
|
-
from .layernorm_grad import LayerNormGrad
|
|
35
|
-
from .logsoftmax import LogSoftmax
|
|
36
|
-
from .logsoftmax_grad import LogSoftmaxGrad
|
|
37
|
-
from .matmul import BatchMatMul, MatMul
|
|
38
|
-
from .maximum_grad import MaximumGrad
|
|
39
|
-
from .minimum_grad import MinimumGrad
|
|
40
|
-
from .oneslike import OnesLike
|
|
41
|
-
from .reduce_mean import ReduceMean
|
|
42
|
-
from .relu_grad import ReluGrad
|
|
43
|
-
from .sigmoid_cross_entropy_with_logits import SigmoidCrossEntropyWithLogits
|
|
44
|
-
from .sigmoid_cross_entropy_with_logits_grad import SigmoidCrossEntropyWithLogitsGrad
|
|
45
|
-
from .sigmoid_grad import SigmoidGrad
|
|
46
|
-
from .slice import Slice
|
|
47
|
-
from .softmax_cross_entropy_with_logits import SoftmaxCrossEntropyWithLogits
|
|
48
|
-
from .softmax_grad_ext import SoftmaxGradExt
|
|
49
|
-
from .sqrt_grad import SqrtGrad
|
|
50
|
-
from .squared_difference import SquaredDifference
|
|
51
|
-
from .square_sum_v1 import SquareSumV1
|
|
52
|
-
from .square_sum_all import SquareSumAll
|
|
53
|
-
from .tanh_grad import TanhGrad
|
|
54
|
-
from .softsign import Softsign
|
|
@@ -1,269 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""GraphKernel expander utils"""
|
|
16
|
-
from abc import ABCMeta, abstractmethod
|
|
17
|
-
from mindspore._extends.graph_kernel.model import model_builder as builder
|
|
18
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class Expander(metaclass=ABCMeta):
    """
    Expander is the base class of expanders.

    The method `_expand` should be overridden to implement the operator detail.
    Subclasses may also override `_check` (or be decorated by `ExpanderInfoValidator`)
    to reject unsupported inputs before any graph is built.
    """

    def __init__(self, expand_info):
        """
        Store the operator description.

        Args:
            expand_info (dict): must contain the keys "name", "input_desc", "output_desc",
                "attr" and "process". Each entry of "input_desc"/"output_desc" is read with the
                keys 'shape', 'data_type' and 'format' elsewhere in this class.
        """
        self.name = expand_info["name"]
        self.inputs = expand_info["input_desc"]
        self.outputs = expand_info["output_desc"]
        self.attrs = expand_info["attr"]
        self.processor = expand_info["process"]

    def run(self):
        """
        Expand the operator to a graph.

        `GraphKernelUnsupportedException` would be raised if check failed.

        Returns:
            The first graph returned by the graph builder, tagged with `self.processor`.
        """
        self._check()
        graph_builder = builder.GraphBuilder()
        with graph_builder.graph_scope(self.name) as graph_scope:
            # transform input_desc to Tensor; from here on `self.inputs` holds tensors, not dicts
            self.inputs = [graph_builder.tensor(inp['shape'], inp['data_type'], inp['format']) for inp in self.inputs]
            graph_scope.set_input(*self.inputs)
            outputs = self._expand(graph_builder)
            # `_expand` may return a single tensor or a list/tuple of tensors; normalize to a sequence
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            self._check_output_same(outputs)
            graph_scope.set_output(*outputs)

        graph = graph_builder.get()[0]
        graph.set_processor(self.processor)
        return graph

    def _check(self):
        """Check inputs"""

    def _check_output_same(self, outputs):
        """
        Verify that the expanded outputs agree with the op description in `self.outputs`.

        Args:
            outputs (list or tuple): tensors produced by `_expand`.

        Raises:
            GKException: if the number of outputs differs, or any output's shape, dtype or
                format differs from the description.
        """
        # Compare counts first so a short result raises GKException instead of IndexError.
        if len(outputs) != len(self.outputs):
            raise GKException("{} 's output number {} is wrong. Expected: {}".format(
                self.__class__.__name__, len(outputs), len(self.outputs)))
        for output, value in zip(outputs, self.outputs):
            if list(output.shape) != list(value['shape']):
                raise GKException("{} 's output shape {} is wrong. Expected:{}".format(
                    self.__class__.__name__, list(output.shape), list(value['shape'])))
            if output.dtype != value['data_type']:
                raise GKException("{} 's output data_type {} is wrong. Expected: {}".format(
                    self.__class__.__name__, output.dtype, value['data_type']))
            if output.data_format != value['format']:
                raise GKException("{} 's output format {} is wrong. Expected: {}".format(
                    self.__class__.__name__, output.data_format, value['format']))

    @abstractmethod
    def _expand(self, graph_builder):
        """Expand operator, this function should be overridden in subclass"""
        raise Exception("_expand() is not implemented in {}".format(self.__class__.__name__))
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
class ExpanderInfoValidator:
    """ExpanderInfoValidator is the utility class which defines the validator decorator for expanders"""

    def __init__(self):
        """Init"""

    @staticmethod
    def _add_check_function(kls, func):
        """
        Rewrite the function `_check` in class Expander
        to append the new `func` after the original checks.

        Args:
            kls (type): the class whose `_check` is replaced.
            func (callable): called with the instance after the previous `_check` returns.
        """
        old_check = getattr(kls, "_check")

        def new_check(obj):
            old_check(obj)
            func(obj)

        setattr(kls, "_check", new_check)

    @staticmethod
    def add_format(*input_format):
        """
        Add new supported format for the operator

        this function will add a list `__supported_formats` into the expander,
        saving the whitelist of formats that this op supports.
        it also rewrites the `_check` function to check the formats.

        Each decorated class owns its own list: a subclass starts from a copy of the formats
        registered on its parent, so registering formats on a subclass never changes the
        formats accepted by the parent.
        """
        format_list_name = "__supported_formats"

        def _check_format(obj):
            inp_formats = [inp['format'] for inp in obj.inputs]
            for formats in getattr(obj, format_list_name):
                if len(formats) != len(inp_formats):
                    raise GKException("For '{}', length of registered format is different from the length of inputs "
                                      "format: {} vs {}".format(obj.name, len(formats), len(inp_formats)))
                if all((fmt == inp for fmt, inp in zip(formats, inp_formats))):
                    return
            raise GKException("Unregistered format ({}) for op {}".format(','.join(inp_formats), obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            # Look only at the class's own namespace: `hasattr` would also find a parent's
            # list, and appending to it would leak this class's formats into the parent.
            if format_list_name not in vars(cls):
                inherited_formats = getattr(cls, format_list_name, None)
                setattr(cls, format_list_name, list(inherited_formats or []))
                if inherited_formats is None:
                    # A parent that registered formats already installed the check, and the
                    # inherited `_check` reads the list through the instance, i.e. this copy.
                    ExpanderInfoValidator._add_check_function(cls, _check_format)
            getattr(cls, format_list_name).append(input_format)
            return cls

        return wrapper

    @staticmethod
    def check_all_formats_same(kls):
        """Check that all formats are the same"""

        def _check_format(obj):
            inp_formats = [inp['format'] for inp in obj.inputs]
            if all((fmt == inp_formats[0] for fmt in inp_formats[1:])):
                return
            raise GKException("[check_all_formats_same] unmatched formats ({}) for op {}".format(
                ','.join(inp_formats), obj.name))

        # Used directly as a class decorator (no arguments), so validate and return the class.
        if not issubclass(kls, Expander):
            raise Exception("{} should be subclass of Expander.".format(kls.__name__))
        ExpanderInfoValidator._add_check_function(kls, _check_format)
        return kls

    @staticmethod
    def check_attrs(*args):
        """Check the attrs exist"""

        def _check_attr(obj):
            for a in args:
                if a not in obj.attrs:
                    raise GKException("attr '{}' does not exist. {}".format(a, obj.name))

        def wrapper(cls):
            if not issubclass(cls, Expander):
                raise Exception("{} should be subclass of Expander.".format(cls.__name__))
            ExpanderInfoValidator._add_check_function(cls, _check_attr)
            return cls

        return wrapper
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
def to_frac_z_axis(ori_shape, ori_axis):
    """
    Map reduce axes of the original shape onto the corresponding fractal NZ axes.

    Parameters
    ----------
    ori_shape: list or tuple
        original shape of input
    ori_axis: list or tuple
        original axis of original shape to operate

    Returns
    -------
    output: list
        axis of the fractal Nz shape. Each of the last two original axes maps to two
        fractal axes, so one extra entry is appended for each of them.
    """
    rank = len(ori_shape)
    last_dim = rank - 1
    second_last_dim = rank - 2
    result = list(ori_axis)
    # Only the original entries are visited; appended entries are already fractal axes.
    for pos in range(len(result)):
        normalized = (result[pos] + rank) % rank
        if normalized == last_dim:
            # akg:[2,3] [1,4] tbe:[2,4] [1,3]
            result[pos] = normalized - 1
            result.append(normalized + 2)
        elif normalized == second_last_dim:
            result[pos] = normalized + 1
            result.append(normalized + 2)
        else:
            result[pos] = normalized
    return result
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
def infer_shape_from_fractalnz(fractal):
    """
    Get the original shape back from a fractal NZ shape.

    Leading dimensions before the last four are copied unchanged. With the last four
    dimensions read as (n1, m1, m0, n0), the result ends with [m1 * m0, n1 * n0].
    """
    rank = len(fractal)
    batch_dims = [fractal[idx] for idx in range(rank - 4)]
    rows = fractal[rank - 3] * fractal[rank - 2]
    cols = fractal[rank - 4] * fractal[rank - 1]
    return batch_dims + [rows, cols]
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
def get_reduced_ori_shape(shape, axis):
    """Return `shape` as a new list with every position listed in `axis` set to 1 (keep-dims reduced shape)."""
    return [1 if idx in axis else dim for idx, dim in enumerate(shape)]
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def get_reduce_axis_shape(shape, data_format, axis):
    """
    Get the reduce axis under format `data_format` and original reduced shape.

    Parameters
    ----------
    shape: list or tuple
        shape of input
    data_format: str
        data format of input
    axis: None, int, list or tuple
        reduce axis of the original shape. None or an empty sequence means all axes;
        negative values count from the end of the original shape. The caller's
        sequence is never modified.

    Returns
    -------
    reduce_axis: list
        reduce axis of the `data_format` shape
    ori_reduced_shape: list
        original reduced shape
    """
    ori_shape = shape
    if data_format == "FRACTAL_NZ":
        ori_shape = infer_shape_from_fractalnz(shape)
    rank = len(ori_shape)

    # Wrap an int before the emptiness test: a bare 0 is a real axis, not "reduce everything".
    if isinstance(axis, int):
        axis = [axis]
    if not axis:
        axis = list(range(rank))
    else:
        # Build a fresh list: works for tuples and leaves the caller's list untouched.
        axis = [ax + rank if ax < 0 else ax for ax in axis]

    ori_reduced_shape = get_reduced_ori_shape(ori_shape, axis)
    reduce_axis = axis
    if data_format == "FRACTAL_NZ":
        reduce_axis = to_frac_z_axis(ori_shape, axis)
    return reduce_axis, ori_reduced_shape
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for addn"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.check_all_formats_same
class AddN(Expander):
    """Expand AddN into a left-to-right chain of binary 'Add' ops: ((x0 + x1) + x2) + ..."""

    def _check(self):
        """Raise GKException when fewer than two inputs are given."""
        input_count = len(self.inputs)
        if input_count >= 2:
            return
        raise GKException("For 'AddN', the inputs num should be greater than 1, but got {}"
                          .format(input_count))

    def _expand(self, graph_builder):
        """Accumulate the inputs with 'Add', starting from the first input."""
        accumulated = self.inputs[0]
        for position in range(1, len(self.inputs)):
            accumulated = graph_builder.emit('Add', [accumulated, self.inputs[position]])
        return accumulated
|
|
@@ -1,152 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchNorm"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .expand_dims import ExpandDims
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """
    BatchNorm expander.

    Inputs (in order): x, scale, offset, mean, variance.
    Attrs: 'is_training', 'momentum', 'epsilon'.
    When x is float16 and the other four inputs are float32, x is computed in float32
    and the result y is cast back to float16.
    """

    @staticmethod
    def _expand_for_nchw(graph_builder, tensor, data_format):
        """
        Reshape a per-channel tensor with ExpandDims.infer_shape(shape, [-1, -1]) when
        `data_format` is DEFAULT or NCHW; otherwise return `tensor` unchanged.

        Presumably this turns (C,) into (C, 1, 1) so it broadcasts against NCHW x -- confirm.
        """
        if data_format not in (DF.DEFAULT, DF.NCHW):
            return tensor
        return graph_builder.emit(
            'Reshape', [tensor], attrs={'shape': ExpandDims.infer_shape(tensor.shape, [-1, -1])})

    def _expand(self, graph_builder):
        """
        Expand BatchNorm.

        Returns:
            Training mode: (y, updated mean, updated variance, batch mean, 1 / sqrt(batch_var + eps)).
            Inference mode: (y, v, v, v, v) where v = variance + epsilon.
        """
        # get op info
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]

        # Mixed precision: compute a float16 x in float32 when all parameters are float32.
        input_x_ori_type = input_x.dtype
        input_x_new_type = input_x.dtype
        if input_x.dtype == "float16" and input_scale.dtype == "float32" and input_offset.dtype == "float32" and \
                input_mean.dtype == "float32" and input_variance.dtype == "float32":
            input_x_new_type = "float32"
        if input_x_new_type != input_x_ori_type:
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': input_x_new_type})

        if self.attrs['is_training']:
            # Pass the (possibly cast) x explicitly instead of overwriting self.inputs[0].
            res_y, mean_res, variance_res, mean_muls, y_sqrt_rec = self._bn_train(graph_builder, input_x)
            if input_x_new_type != input_x_ori_type:
                res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec

        # infer mode: y = offset + scale * (x - mean) / sqrt(variance + epsilon)
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        data_format = input_x.data_format
        input_mean = self._expand_for_nchw(graph_builder, input_mean, data_format)
        input_scale = self._expand_for_nchw(graph_builder, input_scale, data_format)
        input_offset = self._expand_for_nchw(graph_builder, input_offset, data_format)
        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
        var_add_sqrt = self._expand_for_nchw(graph_builder, var_add_sqrt, data_format)
        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
        res_y = graph_builder.emit('Add', [input_offset, x_div])
        if input_x_new_type != input_x_ori_type:
            res_y = graph_builder.emit('Cast', [res_y], attrs={'dst_type': input_x_ori_type})
        return res_y, var_add, var_add, var_add, var_add

    def _bn_train(self, graph_builder, input_x=None):
        """
        Expand BatchNorm for training mode.

        Args:
            graph_builder: the builder used to emit ops.
            input_x: the tensor to normalize. Defaults to `self.inputs[0]`.

        Returns:
            (y, updated mean, updated variance, batch mean, 1 / sqrt(batch_var + epsilon)).

        Raises:
            GKException: if the number of elements reduced per channel is not greater than 1,
                since the unbiased-variance factor num / (num - 1) is then undefined.
        """
        if input_x is None:
            input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
        shape_x = input_x.shape
        data_format = input_x.data_format
        # Reduce over every axis except the channel axis (last for NHWC, axis 1 otherwise).
        if data_format == DF.NHWC:
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        if num <= 1:
            # Avoid ZeroDivisionError in 1.0 / num and num / (num - 1); let the caller fall back.
            from mindspore._extends.graph_kernel.model.model import \
                GraphKernelUnsupportedException as GKException
            raise GKException("For 'BatchNorm', the number of elements reduced per channel should be greater "
                              "than 1 in training mode, but got {}".format(num))
        num_rec_v = graph_builder.value(input_scale.dtype, 1.0 / num)

        # compute mean value of input_x
        mean_sum = graph_builder.emit(
            'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

        # compute biased variance of input_x: mean((x - mean)^2)
        mean_muls_expand = self._expand_for_nchw(graph_builder, mean_muls, data_format)
        x_centered = graph_builder.emit('Sub', [input_x, mean_muls_expand])
        x_centered_sq = graph_builder.emit('Mul', [x_centered, x_centered])
        var_sum = graph_builder.emit(
            'ReduceSum', [x_centered_sq], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
        var_biased = graph_builder.emit('Mul', [var_sum, num_rec_v])

        # y_sqrt_rec means 1 / sqrt(variance + epsilon); it is also returned as an extra output
        scalar_one = 1.0
        scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
        y_add = graph_builder.emit('Add', [var_biased, epsilon_v])
        y_sqrt = graph_builder.emit('Sqrt', [y_add])
        y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])

        # compute res_y = scale * (x - mean) * y_sqrt_rec + offset, reusing x_centered from above
        y_sqrt_rec_expand = self._expand_for_nchw(graph_builder, y_sqrt_rec, data_format)
        y_norm = graph_builder.emit('Mul', [x_centered, y_sqrt_rec_expand])
        input_scale_expand = self._expand_for_nchw(graph_builder, input_scale, data_format)
        res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
        input_offset_expand = self._expand_for_nchw(graph_builder, input_offset, data_format)
        res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

        # compute mean_res: (1 - momentum) * running_mean + momentum * batch_mean, assigned back
        momentum = self.attrs['momentum']
        momentum_v_sub = graph_builder.value(input_scale.dtype, scalar_one - momentum)
        new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
        momentum_v = graph_builder.value(input_scale.dtype, momentum)
        current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
        updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
        mean_res = graph_builder.emit('Assign', [input_mean, updated_moving_mean])

        # variance_res uses the sample (unbiased) variance: multiply by num / (num - 1)
        var_num_v = graph_builder.value(input_scale.dtype, float(num) / (num - 1))
        var_mul_update = graph_builder.emit('Mul', [var_num_v, var_biased])
        new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
        current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
        updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
        variance_res = graph_builder.emit('Assign', [input_variance, updated_moving_variance])
        return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec
|