mindspore 2.1.0__cp38-cp38-win_amd64.whl → 2.2.11__cp38-cp38-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp38-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp38-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp38-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -1
- mindspore/_checkparam.py +23 -29
- mindspore/_extends/graph_kernel/__init__.py +0 -1
- mindspore/_extends/graph_kernel/model/graph_split.py +84 -76
- mindspore/_extends/graph_kernel/model/model_builder.py +9 -50
- mindspore/_extends/graph_kernel/splitter.py +4 -11
- mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +122 -15
- mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +84 -67
- mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +4 -2
- mindspore/_extends/parallel_compile/akg_compiler/util.py +10 -7
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_adapter.py +2 -2
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_helper.py +6 -5
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job.py +1 -1
- mindspore/_extends/parallel_compile/tbe_compiler/tbe_job_manager.py +1 -1
- mindspore/_extends/parse/__init__.py +13 -15
- mindspore/_extends/parse/namespace.py +7 -33
- mindspore/_extends/parse/parser.py +67 -72
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +86 -106
- mindspore/_extends/parse/trope.py +1 -1
- mindspore/_extends/remote/kernel_build_server.py +25 -7
- mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
- mindspore/_install_custom.py +43 -0
- mindspore/amp.py +47 -11
- mindspore/atlprov.dll +0 -0
- mindspore/boost/boost.py +1 -8
- mindspore/boost/boost_cell_wrapper.py +3 -2
- mindspore/boost/grad_accumulation.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +8 -7
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +5 -3
- mindspore/common/_jit_fallback_utils.py +6 -0
- mindspore/common/_register_for_adapter.py +2 -0
- mindspore/common/_register_for_tensor.py +2 -2
- mindspore/common/_stub_tensor.py +13 -0
- mindspore/common/_utils.py +29 -0
- mindspore/common/api.py +174 -259
- mindspore/common/auto_dynamic_shape.py +494 -0
- mindspore/common/dtype.py +18 -11
- mindspore/common/dump.py +6 -4
- mindspore/common/initializer.py +14 -14
- mindspore/common/jit_config.py +33 -15
- mindspore/common/lazy_inline.py +126 -7
- mindspore/common/mindir_util.py +101 -0
- mindspore/common/parameter.py +51 -41
- mindspore/common/seed.py +4 -4
- mindspore/common/sparse_tensor.py +13 -14
- mindspore/common/tensor.py +243 -165
- mindspore/communication/__init__.py +7 -4
- mindspore/communication/_comm_helper.py +83 -4
- mindspore/communication/management.py +152 -84
- mindspore/config/op_info.config +14 -3
- mindspore/context.py +152 -61
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +52 -52
- mindspore/dataset/callback/ds_callback.py +16 -2
- mindspore/dataset/core/config.py +68 -51
- mindspore/dataset/engine/cache_client.py +33 -7
- mindspore/dataset/engine/datasets.py +250 -112
- mindspore/dataset/engine/datasets_audio.py +43 -211
- mindspore/dataset/engine/datasets_standard_format.py +16 -35
- mindspore/dataset/engine/datasets_text.py +43 -67
- mindspore/dataset/engine/datasets_user_defined.py +86 -100
- mindspore/dataset/engine/datasets_vision.py +219 -1029
- mindspore/dataset/engine/iterators.py +11 -4
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +4 -0
- mindspore/dataset/engine/obs/util.py +3 -0
- mindspore/dataset/engine/samplers.py +1 -1
- mindspore/dataset/engine/validators.py +19 -5
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +101 -127
- mindspore/dataset/text/utils.py +205 -138
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/transforms/py_transforms_util.py +40 -12
- mindspore/dataset/transforms/transforms.py +95 -40
- mindspore/dataset/utils/browse_dataset.py +8 -2
- mindspore/dataset/utils/line_reader.py +17 -19
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/c_transforms.py +6 -3
- mindspore/dataset/vision/transforms.py +409 -287
- mindspore/dataset/vision/utils.py +13 -14
- mindspore/dataset/vision/validators.py +11 -1
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/map_parameter.py +14 -0
- mindspore/{nn/optim_ex → experimental/optim}/__init__.py +30 -29
- mindspore/{nn/optim_ex → experimental/optim}/adam.py +60 -67
- mindspore/{nn/optim_ex → experimental/optim}/adamw.py +181 -203
- mindspore/experimental/optim/lr_scheduler.py +1427 -0
- mindspore/{nn/optim_ex → experimental/optim}/optimizer.py +252 -259
- mindspore/{nn/optim_ex → experimental/optim}/sgd.py +147 -152
- mindspore/gen_ops.py +273 -0
- mindspore/include/OWNERS +0 -1
- mindspore/include/api/data_type.h +2 -1
- mindspore/include/api/graph.h +0 -15
- mindspore/include/api/kernel.h +2 -0
- mindspore/include/api/kernel_api.h +37 -12
- mindspore/include/api/model.h +17 -14
- mindspore/include/api/status.h +8 -3
- mindspore/include/api/types.h +37 -4
- mindspore/include/c_api/ms/abstract.h +67 -0
- mindspore/include/c_api/ms/attribute.h +197 -0
- mindspore/include/c_api/ms/base/handle_types.h +43 -0
- mindspore/include/c_api/ms/base/macros.h +32 -0
- mindspore/include/c_api/ms/base/status.h +33 -0
- mindspore/include/c_api/ms/base/types.h +282 -0
- mindspore/include/c_api/ms/context.h +102 -0
- mindspore/include/c_api/ms/graph.h +160 -0
- mindspore/include/c_api/ms/node.h +606 -0
- mindspore/include/c_api/ms/tensor.h +161 -0
- mindspore/include/c_api/ms/value.h +84 -0
- mindspore/include/dataset/constants.h +6 -5
- mindspore/include/dataset/execute.h +23 -13
- mindspore/include/dataset/text.h +26 -26
- mindspore/include/dataset/transforms.h +13 -13
- mindspore/include/dataset/vision.h +60 -60
- mindspore/include/dataset/vision_ascend.h +5 -6
- mindspore/include/dataset/vision_lite.h +17 -17
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +2 -2
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/nn/__init__.py +0 -2
- mindspore/nn/cell.py +313 -74
- mindspore/nn/dynamic_lr.py +21 -21
- mindspore/nn/layer/activation.py +22 -30
- mindspore/nn/layer/basic.py +15 -13
- mindspore/nn/layer/channel_shuffle.py +1 -1
- mindspore/nn/layer/container.py +271 -9
- mindspore/nn/layer/conv.py +323 -204
- mindspore/nn/layer/dense.py +8 -5
- mindspore/nn/layer/embedding.py +33 -27
- mindspore/nn/layer/flash_attention.py +61 -95
- mindspore/nn/layer/image.py +8 -6
- mindspore/nn/layer/math.py +16 -25
- mindspore/nn/layer/normalization.py +107 -66
- mindspore/nn/layer/padding.py +1 -1
- mindspore/nn/layer/pooling.py +131 -109
- mindspore/nn/layer/rnn_cells.py +27 -22
- mindspore/nn/layer/rnns.py +13 -16
- mindspore/nn/layer/thor_layer.py +1 -1
- mindspore/nn/layer/transformer.py +221 -154
- mindspore/nn/learning_rate_schedule.py +9 -1
- mindspore/nn/loss/loss.py +235 -174
- mindspore/nn/optim/ada_grad.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -0
- mindspore/nn/optim/adafactor.py +2 -1
- mindspore/nn/optim/adam.py +7 -4
- mindspore/nn/optim/adamax.py +3 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -3
- mindspore/nn/optim/ftrl.py +6 -5
- mindspore/nn/optim/lamb.py +7 -4
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +5 -3
- mindspore/nn/optim/momentum.py +2 -1
- mindspore/nn/optim/optimizer.py +53 -4
- mindspore/nn/optim/proximal_ada_grad.py +3 -4
- mindspore/nn/optim/rmsprop.py +4 -3
- mindspore/nn/optim/rprop.py +23 -12
- mindspore/nn/optim/sgd.py +26 -11
- mindspore/nn/optim/thor.py +9 -7
- mindspore/nn/probability/bijector/bijector.py +5 -5
- mindspore/nn/probability/bijector/power_transform.py +27 -27
- mindspore/nn/probability/bijector/softplus.py +3 -3
- mindspore/nn/probability/distribution/_utils/custom_ops.py +3 -3
- mindspore/nn/probability/distribution/bernoulli.py +5 -5
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +7 -7
- mindspore/nn/probability/distribution/cauchy.py +0 -1
- mindspore/nn/probability/distribution/distribution.py +3 -3
- mindspore/nn/probability/distribution/gamma.py +3 -3
- mindspore/nn/probability/distribution/geometric.py +4 -4
- mindspore/nn/probability/distribution/gumbel.py +4 -4
- mindspore/nn/probability/distribution/log_normal.py +2 -2
- mindspore/nn/probability/distribution/logistic.py +2 -2
- mindspore/nn/probability/distribution/poisson.py +4 -4
- mindspore/nn/probability/distribution/transformed_distribution.py +3 -3
- mindspore/nn/probability/distribution/uniform.py +6 -6
- mindspore/nn/wrap/__init__.py +4 -2
- mindspore/nn/wrap/cell_wrapper.py +87 -34
- mindspore/nn/wrap/grad_reducer.py +8 -5
- mindspore/nn/wrap/loss_scale.py +105 -42
- mindspore/numpy/array_creations.py +1 -2
- mindspore/numpy/array_ops.py +3 -2
- mindspore/numpy/utils_const.py +5 -5
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/__init__.py +0 -5
- mindspore/ops/_grad_experimental/grad_array_ops.py +2 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +15 -2
- mindspore/ops/_grad_experimental/grad_debug_ops.py +0 -37
- mindspore/ops/_grad_experimental/grad_implementations.py +11 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +2 -216
- mindspore/ops/_grad_experimental/grad_math_ops.py +19 -199
- mindspore/ops/_grad_experimental/grad_sparse.py +15 -0
- mindspore/ops/_grad_experimental/grad_sparse_ops.py +3 -3
- mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +1 -1
- mindspore/ops/_op_impl/aicpu/__init__.py +14 -2
- mindspore/ops/_op_impl/aicpu/add.py +3 -3
- mindspore/ops/_op_impl/aicpu/bias_add_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
- mindspore/ops/_op_impl/{_custom_op/flash_attention/constants.py → aicpu/eps.py} +18 -27
- mindspore/ops/_op_impl/aicpu/gamma.py +2 -2
- mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +21 -2
- mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +6 -3
- mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +0 -1
- mindspore/ops/_op_impl/aicpu/multinomial.py +3 -3
- mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +15 -7
- mindspore/ops/_op_impl/aicpu/random_categorical.py +39 -19
- mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +5 -2
- mindspore/ops/_op_impl/aicpu/random_poisson.py +103 -52
- mindspore/ops/_op_impl/aicpu/random_shuffle.py +17 -15
- mindspore/ops/_op_impl/aicpu/{sparseaddmm.py → sparse_addmm.py} +2 -2
- mindspore/ops/_op_impl/aicpu/{sparsesparsemaximum.py → sparse_sparse_maximum.py} +4 -4
- mindspore/ops/_op_impl/aicpu/standard_laplace.py +5 -5
- mindspore/ops/_op_impl/aicpu/standard_normal.py +5 -5
- mindspore/ops/_op_impl/aicpu/truncated_normal.py +9 -7
- mindspore/ops/_op_impl/aicpu/uniform.py +5 -3
- mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +8 -4
- mindspore/ops/_op_impl/aicpu/uniform_int.py +5 -5
- mindspore/ops/_op_impl/aicpu/uniform_real.py +4 -4
- mindspore/ops/_op_impl/tbe/__init__.py +4 -4
- mindspore/ops/_op_impl/tbe/inplace_index_add.py +7 -3
- mindspore/ops/_op_impl/tbe/trans_data_ds.py +2 -0
- mindspore/ops/_primitive_cache.py +1 -1
- mindspore/ops/_tracefunc.py +45 -13
- mindspore/ops/_utils/utils.py +6 -1
- mindspore/ops/_vmap/vmap_array_ops.py +3 -3
- mindspore/ops/_vmap/vmap_base.py +3 -3
- mindspore/ops/_vmap/vmap_convolution_ops.py +1 -1
- mindspore/ops/_vmap/vmap_grad_math_ops.py +6 -4
- mindspore/ops/_vmap/vmap_math_ops.py +5 -2
- mindspore/ops/_vmap/vmap_nn_ops.py +61 -7
- mindspore/ops/arg_dtype_cast.py +54 -0
- mindspore/ops/composite/base.py +37 -10
- mindspore/ops/composite/math_ops.py +5 -4
- mindspore/ops/composite/multitype_ops/_compile_utils.py +275 -73
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +16 -9
- mindspore/ops/composite/multitype_ops/add_impl.py +43 -4
- mindspore/ops/composite/multitype_ops/getitem_impl.py +42 -4
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +6 -0
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/zeros_like_impl.py +9 -0
- mindspore/ops/deprecated.py +304 -0
- mindspore/ops/function/__init__.py +4 -1
- mindspore/ops/function/array_func.py +174 -193
- mindspore/ops/function/clip_func.py +81 -13
- mindspore/ops/function/debug_func.py +1 -1
- mindspore/ops/function/grad/grad_func.py +18 -9
- mindspore/ops/function/image_func.py +10 -4
- mindspore/ops/function/linalg_func.py +5 -5
- mindspore/ops/function/math_func.py +575 -386
- mindspore/ops/function/nn_func.py +568 -260
- mindspore/ops/function/random_func.py +88 -57
- mindspore/ops/function/sparse_func.py +1 -1
- mindspore/ops/function/sparse_unary_func.py +14 -12
- mindspore/ops/function/vmap_func.py +6 -5
- mindspore/ops/functional.py +15 -10
- mindspore/ops/op_info_register.py +244 -25
- mindspore/ops/operations/__init__.py +31 -19
- mindspore/ops/operations/_grad_ops.py +71 -7
- mindspore/ops/operations/_inner_ops.py +350 -17
- mindspore/ops/operations/_quant_ops.py +4 -8
- mindspore/ops/operations/_sequence_ops.py +42 -0
- mindspore/ops/operations/array_ops.py +68 -282
- mindspore/ops/operations/comm_ops.py +107 -59
- mindspore/ops/operations/custom_ops.py +94 -70
- mindspore/ops/operations/debug_ops.py +8 -4
- mindspore/ops/operations/image_ops.py +18 -12
- mindspore/ops/operations/inner_ops.py +26 -3
- mindspore/ops/operations/math_ops.py +192 -144
- mindspore/ops/operations/nn_ops.py +857 -489
- mindspore/ops/operations/other_ops.py +0 -22
- mindspore/ops/operations/random_ops.py +53 -111
- mindspore/ops/operations/sparse_ops.py +3 -1
- mindspore/ops/primitive.py +24 -18
- mindspore/parallel/_auto_parallel_context.py +68 -8
- mindspore/parallel/_cost_model_context.py +2 -2
- mindspore/parallel/_offload_context.py +17 -3
- mindspore/parallel/_parallel_serialization.py +12 -5
- mindspore/parallel/_ps_context.py +12 -0
- mindspore/parallel/_tensor.py +18 -13
- mindspore/parallel/_transformer/layers.py +5 -3
- mindspore/parallel/_transformer/loss.py +1 -0
- mindspore/parallel/_transformer/moe.py +2 -2
- mindspore/parallel/_transformer/op_parallel_config.py +12 -1
- mindspore/parallel/_transformer/transformer.py +23 -3
- mindspore/parallel/_utils.py +11 -7
- mindspore/parallel/algo_parameter_config.py +85 -5
- mindspore/parallel/checkpoint_transform.py +19 -12
- mindspore/parallel/shard.py +21 -14
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/common/struct_type.py +3 -3
- mindspore/profiler/common/util.py +4 -2
- mindspore/profiler/envprofiling.py +1 -1
- mindspore/profiler/parser/aicpu_data_parser.py +5 -3
- mindspore/profiler/parser/ascend_flops_generator.py +2 -2
- mindspore/profiler/parser/ascend_fpbp_generator.py +1 -1
- mindspore/profiler/parser/ascend_hccl_generator.py +249 -12
- mindspore/profiler/parser/ascend_msprof_exporter.py +150 -255
- mindspore/profiler/parser/ascend_msprof_generator.py +204 -17
- mindspore/profiler/parser/ascend_op_generator.py +6 -6
- mindspore/profiler/parser/ascend_steptrace_generator.py +6 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +14 -187
- mindspore/profiler/parser/base_timeline_generator.py +10 -8
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +16 -12
- mindspore/profiler/parser/flops_parser.py +15 -11
- mindspore/profiler/parser/framework_parser.py +38 -22
- mindspore/profiler/parser/hccl_parser.py +16 -12
- mindspore/profiler/parser/integrator.py +22 -11
- mindspore/profiler/parser/memory_usage_parser.py +2 -2
- mindspore/profiler/parser/minddata_analyzer.py +12 -14
- mindspore/profiler/parser/minddata_pipeline_parser.py +1 -1
- mindspore/profiler/parser/msadvisor_parser.py +8 -4
- mindspore/profiler/parser/op_intermediate_parser.py +5 -2
- mindspore/profiler/parser/optime_parser.py +1 -1
- mindspore/profiler/parser/profiler_info.py +21 -2
- mindspore/profiler/parser/step_trace_parser.py +11 -14
- mindspore/profiler/profiling.py +179 -89
- mindspore/rewrite/api/node.py +102 -19
- mindspore/rewrite/api/node_type.py +5 -1
- mindspore/rewrite/api/pattern_engine.py +1 -1
- mindspore/rewrite/api/scoped_value.py +9 -17
- mindspore/rewrite/api/symbol_tree.py +131 -47
- mindspore/rewrite/ast_helpers/__init__.py +2 -1
- mindspore/rewrite/ast_helpers/ast_finder.py +129 -0
- mindspore/rewrite/ast_helpers/ast_modifier.py +116 -104
- mindspore/rewrite/ast_transformers/flatten_recursive_stmt.py +93 -46
- mindspore/rewrite/common/rewrite_elog.py +5 -1
- mindspore/rewrite/namer.py +33 -24
- mindspore/rewrite/namespace.py +14 -5
- mindspore/{_extends/graph_kernel/expanders/complex → rewrite/node}/__init__.py +9 -9
- mindspore/rewrite/node/call_function.py +79 -0
- mindspore/rewrite/node/cell_container.py +135 -0
- mindspore/rewrite/node/control_flow.py +88 -0
- mindspore/rewrite/{node.py → node/node.py} +273 -234
- mindspore/rewrite/node/node_manager.py +254 -0
- mindspore/rewrite/{topological_manager.py → node/node_topological_manager.py} +13 -46
- mindspore/rewrite/parsers/arguments_parser.py +22 -21
- mindspore/rewrite/parsers/assign_parser.py +216 -221
- mindspore/rewrite/parsers/attribute_parser.py +9 -7
- mindspore/rewrite/parsers/class_def_parser.py +174 -113
- mindspore/rewrite/parsers/constant_parser.py +9 -6
- mindspore/rewrite/parsers/container_parser.py +9 -7
- mindspore/rewrite/parsers/for_parser.py +42 -21
- mindspore/rewrite/parsers/function_def_parser.py +24 -16
- mindspore/rewrite/parsers/if_parser.py +28 -24
- mindspore/rewrite/parsers/module_parser.py +196 -25
- mindspore/rewrite/{parser.py → parsers/parser.py} +4 -2
- mindspore/rewrite/{parser_register.py → parsers/parser_register.py} +1 -1
- mindspore/rewrite/parsers/return_parser.py +6 -6
- mindspore/rewrite/sparsify/sparse_transformer.py +12 -3
- mindspore/rewrite/sparsify/utils.py +1 -1
- mindspore/rewrite/symbol_tree.py +523 -578
- mindspore/rewrite/symbol_tree_builder.py +9 -193
- mindspore/rewrite/symbol_tree_dumper.py +2 -2
- mindspore/run_check/_check_version.py +6 -4
- mindspore/{ops/bprop_mindir → safeguard}/__init__.py +4 -3
- mindspore/safeguard/rewrite_obfuscation.py +541 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +7 -3
- mindspore/train/amp.py +323 -123
- mindspore/train/anf_ir_pb2.py +14 -2
- mindspore/train/callback/_backup_and_restore.py +2 -12
- mindspore/train/callback/_callback.py +29 -4
- mindspore/train/callback/_checkpoint.py +23 -8
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +3 -4
- mindspore/train/callback/_summary_collector.py +15 -8
- mindspore/train/callback/_time_monitor.py +58 -5
- mindspore/train/data_sink.py +5 -11
- mindspore/train/dataset_helper.py +84 -57
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/__init__.py +3 -3
- mindspore/train/metrics/cosine_similarity.py +1 -1
- mindspore/train/metrics/hausdorff_distance.py +3 -2
- mindspore/train/metrics/mean_surface_distance.py +3 -2
- mindspore/train/metrics/metric.py +39 -19
- mindspore/train/metrics/roc.py +2 -2
- mindspore/train/metrics/root_mean_square_surface_distance.py +4 -3
- mindspore/train/mind_ir_pb2.py +85 -36
- mindspore/train/model.py +187 -47
- mindspore/train/serialization.py +487 -161
- mindspore/train/summary/_summary_adapter.py +1 -1
- mindspore/train/summary/_writer_pool.py +3 -2
- mindspore/train/summary/summary_record.py +37 -17
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/train/train_thor/dataset_helper.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/METADATA +7 -4
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/RECORD +429 -486
- mindspore/_extends/graph_kernel/expander.py +0 -80
- mindspore/_extends/graph_kernel/expanders/__init__.py +0 -54
- mindspore/_extends/graph_kernel/expanders/_utils.py +0 -269
- mindspore/_extends/graph_kernel/expanders/addn.py +0 -33
- mindspore/_extends/graph_kernel/expanders/batchnorm.py +0 -152
- mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py +0 -105
- mindspore/_extends/graph_kernel/expanders/clip_by_norm_no_div_sum.py +0 -33
- mindspore/_extends/graph_kernel/expanders/complex/abs.py +0 -30
- mindspore/_extends/graph_kernel/expanders/complex/add.py +0 -44
- mindspore/_extends/graph_kernel/expanders/complex/div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/mul.py +0 -52
- mindspore/_extends/graph_kernel/expanders/complex/real_div.py +0 -62
- mindspore/_extends/graph_kernel/expanders/complex/sub.py +0 -45
- mindspore/_extends/graph_kernel/expanders/conv2d.py +0 -200
- mindspore/_extends/graph_kernel/expanders/dropout_grad.py +0 -30
- mindspore/_extends/graph_kernel/expanders/equal_count.py +0 -50
- mindspore/_extends/graph_kernel/expanders/erfc.py +0 -35
- mindspore/_extends/graph_kernel/expanders/expand_dims.py +0 -50
- mindspore/_extends/graph_kernel/expanders/fused_adam.py +0 -44
- mindspore/_extends/graph_kernel/expanders/fused_adam_weight_decay.py +0 -47
- mindspore/_extends/graph_kernel/expanders/fused_mul_add.py +0 -28
- mindspore/_extends/graph_kernel/expanders/gelu_grad.py +0 -70
- mindspore/_extends/graph_kernel/expanders/gkdropout.py +0 -40
- mindspore/_extends/graph_kernel/expanders/identity.py +0 -25
- mindspore/_extends/graph_kernel/expanders/layernorm.py +0 -93
- mindspore/_extends/graph_kernel/expanders/layernorm_grad.py +0 -113
- mindspore/_extends/graph_kernel/expanders/logsoftmax.py +0 -46
- mindspore/_extends/graph_kernel/expanders/logsoftmax_grad.py +0 -36
- mindspore/_extends/graph_kernel/expanders/matmul.py +0 -80
- mindspore/_extends/graph_kernel/expanders/maximum_grad.py +0 -59
- mindspore/_extends/graph_kernel/expanders/minimum_grad.py +0 -80
- mindspore/_extends/graph_kernel/expanders/oneslike.py +0 -26
- mindspore/_extends/graph_kernel/expanders/reduce_mean.py +0 -43
- mindspore/_extends/graph_kernel/expanders/relu_grad.py +0 -32
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits.py +0 -41
- mindspore/_extends/graph_kernel/expanders/sigmoid_cross_entropy_with_logits_grad.py +0 -35
- mindspore/_extends/graph_kernel/expanders/sigmoid_grad.py +0 -31
- mindspore/_extends/graph_kernel/expanders/slice.py +0 -35
- mindspore/_extends/graph_kernel/expanders/softmax_cross_entropy_with_logits.py +0 -42
- mindspore/_extends/graph_kernel/expanders/softmax_grad_ext.py +0 -41
- mindspore/_extends/graph_kernel/expanders/softsign.py +0 -28
- mindspore/_extends/graph_kernel/expanders/sqrt_grad.py +0 -29
- mindspore/_extends/graph_kernel/expanders/square_sum_all.py +0 -44
- mindspore/_extends/graph_kernel/expanders/square_sum_v1.py +0 -37
- mindspore/_extends/graph_kernel/expanders/squared_difference.py +0 -43
- mindspore/_extends/graph_kernel/expanders/tanh_grad.py +0 -31
- mindspore/_extends/graph_kernel/model/op_infer.py +0 -506
- mindspore/dataset/datapreprocess/__init__.py +0 -20
- mindspore/dataset/datapreprocess/preprocess_imagenet_validate_dataset.py +0 -54
- mindspore/include/api/net.h +0 -142
- mindspore/nn/lr_scheduler.py +0 -262
- mindspore/ops/_grad_experimental/grad_image_ops.py +0 -248
- mindspore/ops/_grad_experimental/grad_linalg_ops.py +0 -181
- mindspore/ops/_grad_experimental/grad_other_ops.py +0 -72
- mindspore/ops/_grad_experimental/grad_scalar_ops.py +0 -112
- mindspore/ops/_grad_experimental/grad_sequence_ops.py +0 -351
- mindspore/ops/_op_impl/_custom_op/flash_attention/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/attention.py +0 -350
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_bwd.py +0 -409
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_fwd.py +0 -578
- mindspore/ops/_op_impl/_custom_op/flash_attention/flash_attention_impl.py +0 -199
- mindspore/ops/_op_impl/_custom_op/flash_attention/tik_ops_utils.py +0 -446
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/__init__.py +0 -0
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/sparse_tiling.py +0 -45
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/strategy.py +0 -67
- mindspore/ops/_op_impl/_custom_op/flash_attention/tiling_strategy/wukong_tiling.py +0 -62
- mindspore/ops/bprop_mindir/BNTrainingReduce_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Broadcast_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Depend_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/DepthwiseConv2dNative_bprop.mindir +0 -138
- mindspore/ops/bprop_mindir/EmbeddingLookup_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Load_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/ScatterNonAliasingAdd_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseGatherV2_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/SparseSoftmaxCrossEntropyWithLogits_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Switch_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TransShape_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/TupleGetItem_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unique_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/Unstack_bprop.mindir +0 -0
- mindspore/ops/bprop_mindir/generate_mindir.py +0 -114
- mindspore/rewrite/node_visitor.py +0 -44
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/WHEEL +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/entry_points.txt +0 -0
- {mindspore-2.1.0.dist-info → mindspore-2.2.11.dist-info}/top_level.txt +0 -0
|
@@ -1,105 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for BatchNormGrad"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
from .expand_dims import ExpandDims
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
|
22
|
-
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
|
23
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
|
|
24
|
-
@VLD.check_attrs('is_training', 'epsilon')
|
|
25
|
-
class BatchNormGrad(Expander):
|
|
26
|
-
"""BatchNormGrad expander"""
|
|
27
|
-
|
|
28
|
-
def _expand(self, graph_builder):
|
|
29
|
-
# get op info
|
|
30
|
-
input_dy = self.inputs[0]
|
|
31
|
-
input_x = self.inputs[1]
|
|
32
|
-
input_scale = self.inputs[2]
|
|
33
|
-
input_save_mean = self.inputs[3]
|
|
34
|
-
input_save_inv_variance = self.inputs[4]
|
|
35
|
-
|
|
36
|
-
reduce_axis = ()
|
|
37
|
-
shape_x = input_x.shape
|
|
38
|
-
if input_x.data_format == DF.NHWC:
|
|
39
|
-
reduce_axis = (0, 1, 2)
|
|
40
|
-
num = shape_x[0] * shape_x[1] * shape_x[2]
|
|
41
|
-
else:
|
|
42
|
-
reduce_axis = (0, 2, 3)
|
|
43
|
-
num = shape_x[0] * shape_x[2] * shape_x[3]
|
|
44
|
-
ori_type = input_x.dtype
|
|
45
|
-
if ori_type == 'float16':
|
|
46
|
-
input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
|
|
47
|
-
if input_dy.dtype == 'float16':
|
|
48
|
-
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
|
|
49
|
-
num_rec = -1.0 / num
|
|
50
|
-
num_rec_v = graph_builder.value(input_scale.dtype, num_rec)
|
|
51
|
-
dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
|
|
52
|
-
|
|
53
|
-
# in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass
|
|
54
|
-
if self.attrs['is_training']:
|
|
55
|
-
inv_variance = input_save_inv_variance
|
|
56
|
-
else:
|
|
57
|
-
epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'])
|
|
58
|
-
var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v])
|
|
59
|
-
sqrt_var_eps = graph_builder.emit('Sqrt', [var_add])
|
|
60
|
-
scalar_one = 1.0
|
|
61
|
-
scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one)
|
|
62
|
-
inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])
|
|
63
|
-
|
|
64
|
-
# compute dgamma
|
|
65
|
-
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
|
66
|
-
input_save_mean = graph_builder.emit(
|
|
67
|
-
'Reshape', [input_save_mean], attrs={'shape': ExpandDims.infer_shape(input_save_mean.shape, [-1, -1])})
|
|
68
|
-
inv_variance = graph_builder.emit(
|
|
69
|
-
'Reshape', [inv_variance], attrs={'shape': ExpandDims.infer_shape(inv_variance.shape, [-1, -1])})
|
|
70
|
-
input_scale = graph_builder.emit(
|
|
71
|
-
'Reshape', [input_scale], attrs={'shape': ExpandDims.infer_shape(input_scale.shape, [-1, -1])})
|
|
72
|
-
x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
|
|
73
|
-
x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
|
|
74
|
-
dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
|
|
75
|
-
dgamma = graph_builder.emit(
|
|
76
|
-
'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
|
|
77
|
-
|
|
78
|
-
# compute dx
|
|
79
|
-
if self.attrs['is_training']:
|
|
80
|
-
tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
|
|
81
|
-
if input_x.data_format in (DF.DEFAULT, DF.NCHW):
|
|
82
|
-
dgamma_expand = graph_builder.emit(
|
|
83
|
-
'Reshape', [dgamma], attrs={'shape': ExpandDims.infer_shape(dgamma.shape, [-1, -1])})
|
|
84
|
-
tmp_b = graph_builder.emit(
|
|
85
|
-
'Reshape', [tmp_b], attrs={'shape': ExpandDims.infer_shape(tmp_b.shape, [-1, -1])})
|
|
86
|
-
else:
|
|
87
|
-
dgamma_expand = dgamma
|
|
88
|
-
x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand])
|
|
89
|
-
tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul])
|
|
90
|
-
tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b])
|
|
91
|
-
tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c])
|
|
92
|
-
gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add])
|
|
93
|
-
dx = graph_builder.emit('Mul', [inv_variance, gamma_mul])
|
|
94
|
-
else:
|
|
95
|
-
y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
|
|
96
|
-
dx = graph_builder.emit('Mul', [inv_variance, y_scale])
|
|
97
|
-
if ori_type == 'float16':
|
|
98
|
-
dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})
|
|
99
|
-
|
|
100
|
-
# set output tensors' data_format
|
|
101
|
-
dx.data_format = self.outputs[0]['format']
|
|
102
|
-
dgamma.data_format = self.outputs[1]['format']
|
|
103
|
-
dbeta.data_format = self.outputs[2]['format']
|
|
104
|
-
|
|
105
|
-
return dx, dgamma, dbeta
|
|
@@ -1,33 +0,0 @@
|
|
|
1
|
-
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for ClipByNormNoDivSum"""
|
|
16
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
@VLD.check_all_formats_same
class ClipByNormNoDivSum(Expander):
    """Expander that lowers ClipByNormNoDivSum into basic graph-kernel ops."""

    def _expand(self, graph_builder):
        """Emit the op sequence for ClipByNormNoDivSum and return its output.

        The four inputs are consumed in order (x0, x1, x2, x3) and the emitted
        graph is:

            mask   = Greater(x0, x1)
            picked = Select(mask, x0, x2)
            root   = Sqrt(picked)
            chosen = Select(mask, root, x0)
            out    = Maximum(chosen, x3)
        """
        x0, x1, x2, x3 = self.inputs
        emit = graph_builder.emit

        # The same comparison mask drives both Select ops.
        mask = emit('Greater', [x0, x1])
        picked = emit('Select', [mask, x0, x2])
        root = emit('Sqrt', [picked])
        chosen = emit('Select', [mask, root, x0])
        return emit('Maximum', [chosen, x3])
|
|
@@ -1,30 +0,0 @@
|
|
|
1
|
-
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cabs"""
|
|
16
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class CAbs(Expander):
    """Expander that lowers CAbs into real-valued graph-kernel ops."""

    def _expand(self, graph_builder):
        """Emit sqrt(real(x)^2 + imag(x)^2) for the single input and return it."""
        emit = graph_builder.emit
        source = self.inputs[0]

        # Split the input into its two components (CReal first, then CImag).
        real_part, imag_part = [emit(op_name, [source]) for op_name in ('CReal', 'CImag')]
        # Square each component by multiplying it with itself.
        real_sq, imag_sq = [emit('Mul', [part, part]) for part in (real_part, imag_part)]
        magnitude_sq = emit('Add', [real_sq, imag_sq])
        return emit('Sqrt', [magnitude_sq])
|
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cadd"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CAdd(Expander):
    """CAdd expander.

    Lowers an addition involving complex operands into ``Add`` ops on the
    ``CReal``/``CImag`` components, recombined with a ``Complex`` op.
    """

    def _expand(self, graph_builder):
        """Emit the decomposed addition and return the resulting tensor.

        Branches, tested in this order:
          * both inputs have the same dtype: real and imaginary parts are
            added component-wise;
          * dtypes differ and ``input_x`` is complex64/complex128: ``input_y``
            is added to the real part of ``input_x``, the imaginary part is
            passed through;
          * dtypes differ and ``input_y`` is complex64/complex128: ``input_x``
            is added to the real part of ``input_y``, the imaginary part is
            passed through.

        Raises:
            GraphKernelUnsupportedException: if the dtypes differ and neither
                input is complex64/complex128. Previously this case fell
                through every branch and failed with ``UnboundLocalError``
                on ``return result``.
        """
        complex_dtypes = ("complex64", "complex128")
        emit = graph_builder.emit
        input_x, input_y = self.inputs

        if input_x.dtype == input_y.dtype:
            x_real = emit('CReal', [input_x])
            y_real = emit('CReal', [input_y])
            x_imag = emit('CImag', [input_x])
            y_imag = emit('CImag', [input_y])
            result_real = emit('Add', [x_real, y_real])
            result_imag = emit('Add', [x_imag, y_imag])
            result = emit('Complex', [result_real, result_imag])
        elif input_x.dtype in complex_dtypes:
            x_real = emit('CReal', [input_x])
            x_imag = emit('CImag', [input_x])
            x_real_add_y = emit('Add', [x_real, input_y])
            result = emit('Complex', [x_real_add_y, x_imag])
        elif input_y.dtype in complex_dtypes:
            y_real = emit('CReal', [input_y])
            y_imag = emit('CImag', [input_y])
            y_real_add_x = emit('Add', [y_real, input_x])
            result = emit('Complex', [y_real_add_x, y_imag])
        else:
            # Imported locally: this module does not import the exception at file level.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
            raise GraphKernelUnsupportedException(
                "For 'CAdd', inputs should have the same data type or at least one of them should be "
                "complex64/complex128, but got {} and {}".format(input_x.dtype, input_y.dtype))
        return result
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cdiv"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CDiv(Expander):
    """CDiv expander.

    Lowers a division involving complex operands into real-valued ``Mul``,
    ``Add``, ``Sub``, ``Neg`` and ``RealDiv`` ops on the ``CReal``/``CImag``
    components, recombined with a ``Complex`` op.
    """

    def _expand(self, graph_builder):
        """Emit the decomposed division ``input_x / input_y`` and return the result.

        Branches, tested in this order:
          * both inputs have the same dtype: with x = a + bi and y = c + di the
            result is ((ac + bd) + (bc - ad)i) / (c^2 + d^2);
          * dtypes differ and ``input_x`` is complex64/complex128: both
            components of ``input_x`` are divided by ``input_y``;
          * dtypes differ and ``input_y`` is complex64/complex128: with
            y = c + di the result is (x*c + x*(-d)i) / (c^2 + d^2).

        Raises:
            GraphKernelUnsupportedException: if the dtypes differ and neither
                input is complex64/complex128. Previously this case fell
                through every branch and failed with ``UnboundLocalError``
                on ``return result``.
        """
        complex_dtypes = ("complex64", "complex128")
        emit = graph_builder.emit
        input_x, input_y = self.inputs

        if input_x.dtype == input_y.dtype:
            x_real = emit('CReal', [input_x])
            y_real = emit('CReal', [input_y])
            x_imag = emit('CImag', [input_x])
            y_imag = emit('CImag', [input_y])
            # Denominator: |y|^2 = c^2 + d^2.
            square_y_real = emit('Mul', [y_real, y_real])
            square_y_imag = emit('Mul', [y_imag, y_imag])
            final_denominator = emit('Add', [square_y_real, square_y_imag])
            # Numerator: x multiplied by the conjugate of y.
            x_real_mul_y_real = emit('Mul', [x_real, y_real])
            x_imag_mul_y_imag = emit('Mul', [x_imag, y_imag])
            x_real_mul_y_imag = emit('Mul', [x_real, y_imag])
            x_imag_mul_y_real = emit('Mul', [x_imag, y_real])
            final_numerator_real = emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
            final_numerator_imag = emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
            result_real = emit('RealDiv', [final_numerator_real, final_denominator])
            result_imag = emit('RealDiv', [final_numerator_imag, final_denominator])
            result = emit('Complex', [result_real, result_imag])
        elif input_x.dtype in complex_dtypes:
            x_real = emit('CReal', [input_x])
            x_imag = emit('CImag', [input_x])
            x_real_div_y = emit('RealDiv', [x_real, input_y])
            x_imag_div_y = emit('RealDiv', [x_imag, input_y])
            result = emit('Complex', [x_real_div_y, x_imag_div_y])
        elif input_y.dtype in complex_dtypes:
            y_real = emit('CReal', [input_y])
            y_imag = emit('CImag', [input_y])
            neg_y_imag = emit('Neg', [y_imag])
            square_y_real = emit('Mul', [y_real, y_real])
            square_y_imag = emit('Mul', [y_imag, y_imag])
            final_denominator = emit('Add', [square_y_real, square_y_imag])
            x_mul_y_real = emit('Mul', [input_x, y_real])
            x_mul_neg_y_imag = emit('Mul', [input_x, neg_y_imag])
            y_real_div_x = emit('RealDiv', [x_mul_y_real, final_denominator])
            y_imag_div_x = emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
            result = emit('Complex', [y_real_div_x, y_imag_div_x])
        else:
            # Imported locally: this module does not import the exception at file level.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
            raise GraphKernelUnsupportedException(
                "For 'CDiv', inputs should have the same data type or at least one of them should be "
                "complex64/complex128, but got {} and {}".format(input_x.dtype, input_y.dtype))
        return result
|
|
@@ -1,52 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for cmul"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CMul(Expander):
    """CMul expander.

    Lowers a multiplication involving complex operands into real-valued
    ``Mul``, ``Add`` and ``Sub`` ops on the ``CReal``/``CImag`` components,
    recombined with a ``Complex`` op.
    """

    def _expand(self, graph_builder):
        """Emit the decomposed multiplication and return the resulting tensor.

        Branches, tested in this order:
          * both inputs have the same dtype: with x = a + bi and y = c + di the
            result is (ac - bd) + (ad + bc)i;
          * dtypes differ and ``input_x`` is complex64/complex128: both
            components of ``input_x`` are multiplied by ``input_y``;
          * dtypes differ and ``input_y`` is complex64/complex128: both
            components of ``input_y`` are multiplied by ``input_x``.

        Raises:
            GraphKernelUnsupportedException: if the dtypes differ and neither
                input is complex64/complex128. Previously this case fell
                through every branch and failed with ``UnboundLocalError``
                on ``return result``.
        """
        complex_dtypes = ("complex64", "complex128")
        emit = graph_builder.emit
        input_x, input_y = self.inputs

        if input_x.dtype == input_y.dtype:
            x_real = emit('CReal', [input_x])
            y_real = emit('CReal', [input_y])
            x_imag = emit('CImag', [input_x])
            y_imag = emit('CImag', [input_y])
            x_real_mul_y_real = emit('Mul', [x_real, y_real])
            x_imag_mul_y_imag = emit('Mul', [x_imag, y_imag])
            x_real_mul_y_imag = emit('Mul', [x_real, y_imag])
            x_imag_mul_y_real = emit('Mul', [x_imag, y_real])
            result_real = emit('Sub', [x_real_mul_y_real, x_imag_mul_y_imag])
            result_imag = emit('Add', [x_real_mul_y_imag, x_imag_mul_y_real])
            result = emit('Complex', [result_real, result_imag])
        elif input_x.dtype in complex_dtypes:
            x_real = emit('CReal', [input_x])
            x_imag = emit('CImag', [input_x])
            x_real_mul_y = emit('Mul', [x_real, input_y])
            x_imag_mul_y = emit('Mul', [x_imag, input_y])
            result = emit('Complex', [x_real_mul_y, x_imag_mul_y])
        elif input_y.dtype in complex_dtypes:
            y_real = emit('CReal', [input_y])
            y_imag = emit('CImag', [input_y])
            y_real_mul_x = emit('Mul', [y_real, input_x])
            y_imag_mul_x = emit('Mul', [y_imag, input_x])
            result = emit('Complex', [y_real_mul_x, y_imag_mul_x])
        else:
            # Imported locally: this module does not import the exception at file level.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
            raise GraphKernelUnsupportedException(
                "For 'CMul', inputs should have the same data type or at least one of them should be "
                "complex64/complex128, but got {} and {}".format(input_x.dtype, input_y.dtype))

        return result
|
|
@@ -1,62 +0,0 @@
|
|
|
1
|
-
# Copyright 2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for crealdiv"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CRealDiv(Expander):
    """CRealDiv expander.

    Lowers a division involving complex operands into real-valued ``Mul``,
    ``Add``, ``Sub``, ``Neg`` and ``RealDiv`` ops on the ``CReal``/``CImag``
    components, recombined with a ``Complex`` op.
    """

    def _expand(self, graph_builder):
        """Emit the decomposed division ``input_x / input_y`` and return the result.

        Branches, tested in this order:
          * both inputs have the same dtype: with x = a + bi and y = c + di the
            result is ((ac + bd) + (bc - ad)i) / (c^2 + d^2);
          * dtypes differ and ``input_x`` is complex64/complex128: both
            components of ``input_x`` are divided by ``input_y``;
          * dtypes differ and ``input_y`` is complex64/complex128: with
            y = c + di the result is (x*c + x*(-d)i) / (c^2 + d^2).

        Raises:
            GraphKernelUnsupportedException: if the dtypes differ and neither
                input is complex64/complex128. Previously this case fell
                through every branch and failed with ``UnboundLocalError``
                on ``return result``.
        """
        complex_dtypes = ("complex64", "complex128")
        emit = graph_builder.emit
        input_x, input_y = self.inputs

        if input_x.dtype == input_y.dtype:
            x_real = emit('CReal', [input_x])
            y_real = emit('CReal', [input_y])
            x_imag = emit('CImag', [input_x])
            y_imag = emit('CImag', [input_y])
            # Denominator: |y|^2 = c^2 + d^2.
            square_y_real = emit('Mul', [y_real, y_real])
            square_y_imag = emit('Mul', [y_imag, y_imag])
            final_denominator = emit('Add', [square_y_real, square_y_imag])
            # Numerator: x multiplied by the conjugate of y.
            x_real_mul_y_real = emit('Mul', [x_real, y_real])
            x_imag_mul_y_imag = emit('Mul', [x_imag, y_imag])
            x_real_mul_y_imag = emit('Mul', [x_real, y_imag])
            x_imag_mul_y_real = emit('Mul', [x_imag, y_real])
            final_numerator_real = emit('Add', [x_real_mul_y_real, x_imag_mul_y_imag])
            final_numerator_imag = emit('Sub', [x_imag_mul_y_real, x_real_mul_y_imag])
            result_real = emit('RealDiv', [final_numerator_real, final_denominator])
            result_imag = emit('RealDiv', [final_numerator_imag, final_denominator])
            result = emit('Complex', [result_real, result_imag])
        elif input_x.dtype in complex_dtypes:
            x_real = emit('CReal', [input_x])
            x_imag = emit('CImag', [input_x])
            x_real_div_y = emit('RealDiv', [x_real, input_y])
            x_imag_div_y = emit('RealDiv', [x_imag, input_y])
            result = emit('Complex', [x_real_div_y, x_imag_div_y])
        elif input_y.dtype in complex_dtypes:
            y_real = emit('CReal', [input_y])
            y_imag = emit('CImag', [input_y])
            neg_y_imag = emit('Neg', [y_imag])
            square_y_real = emit('Mul', [y_real, y_real])
            square_y_imag = emit('Mul', [y_imag, y_imag])
            final_denominator = emit('Add', [square_y_real, square_y_imag])
            x_mul_y_real = emit('Mul', [input_x, y_real])
            x_mul_neg_y_imag = emit('Mul', [input_x, neg_y_imag])
            y_real_div_x = emit('RealDiv', [x_mul_y_real, final_denominator])
            y_imag_div_x = emit('RealDiv', [x_mul_neg_y_imag, final_denominator])
            result = emit('Complex', [y_real_div_x, y_imag_div_x])
        else:
            # Imported locally: this module does not import the exception at file level.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
            raise GraphKernelUnsupportedException(
                "For 'CRealDiv', inputs should have the same data type or at least one of them should be "
                "complex64/complex128, but got {} and {}".format(input_x.dtype, input_y.dtype))
        return result
|
|
@@ -1,45 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for csub"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
17
|
-
from mindspore._extends.graph_kernel.expanders._utils import Expander, ExpanderInfoValidator as VLD
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
class CSub(Expander):
    """CSub expander.

    Lowers a subtraction involving complex operands into ``Sub``/``Neg`` ops
    on the ``CReal``/``CImag`` components, recombined with a ``Complex`` op.
    """

    def _expand(self, graph_builder):
        """Emit the decomposed subtraction ``input_x - input_y`` and return the result.

        Branches, tested in this order:
          * both inputs have the same dtype: real and imaginary parts are
            subtracted component-wise;
          * dtypes differ and ``input_x`` is complex64/complex128: ``input_y``
            is subtracted from the real part of ``input_x``, the imaginary
            part is passed through;
          * dtypes differ and ``input_y`` is complex64/complex128: the real
            part of ``input_y`` is subtracted from ``input_x`` and the
            imaginary part of ``input_y`` is negated.

        Raises:
            GraphKernelUnsupportedException: if the dtypes differ and neither
                input is complex64/complex128. Previously this case fell
                through every branch and failed with ``UnboundLocalError``
                on ``return result``.
        """
        complex_dtypes = ("complex64", "complex128")
        emit = graph_builder.emit
        input_x, input_y = self.inputs

        if input_x.dtype == input_y.dtype:
            x_real = emit('CReal', [input_x])
            y_real = emit('CReal', [input_y])
            x_imag = emit('CImag', [input_x])
            y_imag = emit('CImag', [input_y])
            result_real = emit('Sub', [x_real, y_real])
            result_imag = emit('Sub', [x_imag, y_imag])
            result = emit('Complex', [result_real, result_imag])
        elif input_x.dtype in complex_dtypes:
            x_real = emit('CReal', [input_x])
            x_imag = emit('CImag', [input_x])
            x_real_sub_y = emit('Sub', [x_real, input_y])
            result = emit('Complex', [x_real_sub_y, x_imag])
        elif input_y.dtype in complex_dtypes:
            y_real = emit('CReal', [input_y])
            y_imag = emit('CImag', [input_y])
            x_sub_y_real = emit('Sub', [input_x, y_real])
            # x - (c + di) = (x - c) + (-d)i, hence the negated imaginary part.
            neg_y_imag = emit('Neg', [y_imag])
            result = emit('Complex', [x_sub_y_real, neg_y_imag])
        else:
            # Imported locally: this module does not import the exception at file level.
            from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException
            raise GraphKernelUnsupportedException(
                "For 'CSub', inputs should have the same data type or at least one of them should be "
                "complex64/complex128, but got {} and {}".format(input_x.dtype, input_y.dtype))
        return result
|
|
@@ -1,200 +0,0 @@
|
|
|
1
|
-
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
|
2
|
-
#
|
|
3
|
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
-
# you may not use this file except in compliance with the License.
|
|
5
|
-
# You may obtain a copy of the License at
|
|
6
|
-
#
|
|
7
|
-
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
-
#
|
|
9
|
-
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
-
# See the License for the specific language governing permissions and
|
|
13
|
-
# limitations under the License.
|
|
14
|
-
# ===========================================================================
|
|
15
|
-
"""generate json desc for Conv2D"""
|
|
16
|
-
from mindspore._extends.graph_kernel.model.op_infer import check_format_any, check_nd, conv_had_pad
|
|
17
|
-
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
|
|
18
|
-
from mindspore._extends.graph_kernel.model.model import GraphKernelUnsupportedException as GKException
|
|
19
|
-
from ._utils import Expander, ExpanderInfoValidator as VLD
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
@VLD.add_format(DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NHWC, DF.NHWC)
@VLD.check_attrs('format', 'pad_list', 'pad_mode', 'groups', 'group', 'kernel_size', 'stride', 'dilation')
class Conv2D(Expander):
    """
    Conv2D expander

    Currently, only Conv2D that meets several conditions can be expanded, other cases will be skipped
    (the checks below raise GKException when a condition is not met).
    Conditions enforced by _common_check and _check:
        both inputs are float16, and the formats pass check_format_any(..., DF.NHWC).
        attr groups and group are 1.
        attr dilation is [1, 1, 1, 1].
        first dimension of each input is a multiple of 32 (N0_CHANNEL_ALIGN / N1_CHANNEL_ALIGN).
        last (C) dimension of both inputs is equal and a multiple of 16 (C_CHANNEL_ALIGN).
        MatMul path (1x1 kernel, stride [1, 1, 1, 1], aligned m/n/k): k <= K_LIMIT and m*n*k < MNK_LIMIT.
        otherwise: stride is [1, 1, 2, 2] and output N*H*W is a multiple of 128 (OUT_NHW_ALIGN).
    """
    # Divisibility required of m, n and k for the MatMul rewrite (see _optimize_to_matmul).
    M_ALIGN = 32
    N_ALIGN = 32
    K_ALIGN = 16
    # Upper bounds checked in _check when the MatMul rewrite is selected.
    K_LIMIT = 800
    MNK_LIMIT = 3 * (10 ** 10)
    # Required multiples for the first dimension of input 0 / input 1 and for the C dimension.
    N0_CHANNEL_ALIGN = 32
    N1_CHANNEL_ALIGN = 32
    C_CHANNEL_ALIGN = 16
    # Required multiple of output N*H*W when the op stays a Conv2D.
    OUT_NHW_ALIGN = 128

    def __init__(self, expand_info):
        """Initialise state from the output/input descriptions; real values are filled in by _check."""
        super().__init__(expand_info)
        self.dst_type = self.outputs[0]['data_type']
        self.dst_format = self.outputs[0]['format']
        self.has_pad = False
        self.can_optimize_to_matmul = False
        # Padded shapes default to the raw input shapes until _check recomputes them.
        self.shape_0_pad = self.inputs[0]['shape']
        self.shape_1_pad = self.inputs[1]['shape']
        # MatMul dimensions, set by _check.
        self.m = 0
        self.n = 0
        self.k = 0

    def _optimize_to_matmul(self):
        """Return True when the convolution can be rewritten as a MatMul.

        Requires a 1x1 kernel (dims 1 and 2 of input 1), stride and dilation equal to
        [1, 1, 1, 1], and self.m / self.n / self.k divisible by M_ALIGN / N_ALIGN / K_ALIGN.
        Reads self.m, self.n and self.k, so _check must have assigned them first.
        """
        stride = self.attrs['stride']
        dilation = self.attrs['dilation']
        _, h, w, _ = self.inputs[1]['shape']
        if h == 1 and w == 1 and stride == [1, 1, 1, 1] and dilation == [1, 1, 1, 1] and \
                self.m % self.M_ALIGN == 0 and self.n % self.N_ALIGN == 0 and self.k % self.K_ALIGN == 0:
            return True
        return False

    def _common_check(self):
        """common check for inputs and attrs

        Raises GKException if either input is not float16, if groups/group is not 1,
        or if dilation is not [1, 1, 1, 1]. Format and rank validation is delegated to
        check_format_any and check_nd.
        """
        type_0 = self.inputs[0]['data_type']
        type_1 = self.inputs[1]['data_type']
        if type_0 != "float16" or type_1 != "float16":
            raise GKException("For 'Conv2D', inputs data type should be both float16, but got {} and {}"
                              .format(type_0, type_1))

        formats = [self.inputs[0]['format'], self.inputs[1]['format'], self.attrs['format']]
        check_format_any(formats, DF.NHWC)

        groups = self.attrs['groups']
        group = self.attrs['group']
        if groups != 1 or group != 1:
            raise GKException("For 'Conv2D', value of attr 'groups' and 'group' should be both 1, but got {} and {}."
                              .format(groups, group))

        dilation = self.attrs['dilation']
        check_nd(dilation, 4)
        if dilation != [1, 1, 1, 1]:
            raise GKException("For 'Conv2D', value of attr 'dilation' should be [1, 1, 1, 1], but got {}"
                              .format(dilation))

    def _check(self):
        """Validate the op and compute the state that _expand relies on.

        Sets self.has_pad, self.m/self.n/self.k, self.can_optimize_to_matmul and the padded
        shapes self.shape_0_pad / self.shape_1_pad. Raises GKException when a shape, stride
        or size requirement is not met.
        """
        self._common_check()

        pad_list = self.attrs['pad_list']
        check_nd(pad_list, 4)
        self.has_pad = conv_had_pad(pad_list, self.attrs['pad_mode'])

        shape_0 = self.inputs[0]['shape']
        shape_1 = self.inputs[1]['shape']
        stride = self.attrs['stride']
        check_nd(shape_0, 4)
        check_nd(shape_1, 4)
        check_nd(stride, 4)
        n0, h0, w0, c0 = shape_0
        n1, h1, w1, c1 = shape_1
        if (n0 % self.N0_CHANNEL_ALIGN) != 0:
            raise GKException("For 'Conv2D', N channel of first input should be multiples of {}, but got {}"
                              .format(self.N0_CHANNEL_ALIGN, n0))
        if (n1 % self.N1_CHANNEL_ALIGN) != 0:
            raise GKException("For 'Conv2D', N channel of second input should be multiples of {}, but got {}"
                              .format(self.N1_CHANNEL_ALIGN, n1))
        if c0 != c1 or (c0 % self.C_CHANNEL_ALIGN) != 0:
            raise GKException("For 'Conv2D', C channel of inputs should be same and also be multiples of {}, but got "
                              "{} and {}".format(self.C_CHANNEL_ALIGN, c0, c1))
        # n0 pad
        # NOTE: the checks above already require exact multiples, so the round-ups below
        # (n0, c0, n1) leave the values unchanged.
        n0 = ((n0 + self.N0_CHANNEL_ALIGN - 1) //
              self.N0_CHANNEL_ALIGN) * self.N0_CHANNEL_ALIGN
        # h0, w0 pad
        # pad_list[0]/[1] are added to H and pad_list[2]/[3] to W.
        if self.has_pad:
            h0 = h0 + pad_list[0] + pad_list[1]
            w0 = w0 + pad_list[2] + pad_list[3]
        # c0, c1 pad
        c0 = ((c0 + self.C_CHANNEL_ALIGN - 1) // self.C_CHANNEL_ALIGN) * self.C_CHANNEL_ALIGN
        c1 = c0
        # n1 pad
        n1 = ((n1 + self.N1_CHANNEL_ALIGN - 1) //
              self.N1_CHANNEL_ALIGN) * self.N1_CHANNEL_ALIGN

        # check if can optimize to matmul
        # m, n, k must be assigned before the call: _optimize_to_matmul reads them.
        self.m, self.n, self.k = n0 * h0 * w0, n1, c1
        self.can_optimize_to_matmul = self._optimize_to_matmul()

        # requirements
        if self.can_optimize_to_matmul:
            if self.k > self.K_LIMIT:
                raise GKException("For 'Conv2D', if transformed to 'MatMul', C0 should not be larger than {}, but got "
                                  "{}".format(self.K_LIMIT, self.k))
            if self.m * self.n * self.k >= self.MNK_LIMIT:
                raise GKException("For 'Conv2D', if transformed to 'MatMul', The total size should not be larger than "
                                  "{}, but got {}".format(self.MNK_LIMIT, self.m * self.n * self.k))
        else:
            # Output spatial size computed from the (already padded) input size, kernel size and stride.
            out_h, out_w = (h0 - h1) // stride[-2] + 1, (w0 - w1) // stride[-1] + 1
            if ((n0 * out_h * out_w) % self.OUT_NHW_ALIGN) != 0:
                raise GKException("For 'Conv2D', N({}) * H({}) * W({}) of output should be multiplies of {}"
                                  .format(n0, out_h, out_w, self.OUT_NHW_ALIGN))
            if stride != [1, 1, 2, 2]:
                raise GKException("For 'Conv2D', value of attr 'stride' should be [1, 1, 2, 2], but got {}"
                                  .format(stride))

        self.shape_0_pad = [n0, h0, w0, c0]
        self.shape_1_pad = [n1, h1, w1, c1]

    def _expand(self, graph_builder):
        """Emit the padded Conv2D (or the equivalent Reshape/MatMul/Reshape) and return the result.

        Relies on the state computed by _check (has_pad, can_optimize_to_matmul, m/n/k, padded shapes).
        """
        # NOTE(review): inputs are read via attributes here (.shape) but indexed as dicts in
        # __init__/_check — presumably the base class replaces them with graph tensors before
        # calling _expand; confirm against Expander.
        input_0 = self.inputs[0]
        input_1 = self.inputs[1]
        n0, _, _, c0 = input_0.shape
        n1, _, _, c1 = input_1.shape
        n0_p, h0_p, w0_p, c0_p = self.shape_0_pad
        n1_p, _, _, c1_p = self.shape_1_pad

        pad_value = 0
        # input0 pad
        input_0_pad_before = [0, 0, 0, 0]
        input_0_pad_after = [0, 0, 0, 0]
        if self.has_pad:
            # pad_list[0]/[2] become the 'head' padding of H/W, pad_list[1]/[3] the 'tail' padding.
            pad_list = self.attrs['pad_list']
            input_0_pad_before = [0, pad_list[0], pad_list[2], 0]
            input_0_pad_after = [0, pad_list[1], pad_list[3], 0]
        # Alignment padding of the first and last dimension goes on the tail side.
        input_0_pad_after[0] = n0_p - n0
        input_0_pad_after[3] = c0_p - c0
        if input_0_pad_before != [0, 0, 0, 0] or input_0_pad_after != [0, 0, 0, 0]:
            input_0 = graph_builder.emit('PadAkg', [input_0], attrs={'head': input_0_pad_before,
                                                                     'tail': input_0_pad_after,
                                                                     'pad_val': pad_value})
        # input1 pad
        input_1_pad_after = [n1_p - n1, 0, 0, c1_p - c1]
        if input_1_pad_after != [0, 0, 0, 0]:
            input_1 = graph_builder.emit('PadAkg', [input_1], attrs={'head': [0, 0, 0, 0],
                                                                     'tail': input_1_pad_after,
                                                                     'pad_val': pad_value})
        if self.can_optimize_to_matmul:
            # Flatten both operands to 2-D, multiply with the second operand transposed,
            # then restore the 4-D output shape.
            a = graph_builder.emit('Reshape', [input_0], attrs={'shape': [self.m, self.k]})
            b = graph_builder.emit('Reshape', [input_1], attrs={'shape': [self.n, self.k]})
            c = graph_builder.emit('MatMul', [a, b], attrs={'transpose_a': False,
                                                            'transpose_b': True,
                                                            'dst_type': self.dst_type})
            result = graph_builder.emit('Reshape', [c], attrs={'shape': [n0_p, h0_p, w0_p, n1_p],
                                                               'format': self.dst_format})
        else:
            # NOTE: `attrs` aliases self.attrs, so the two assignments below modify the
            # expander's own attribute dict. pad_list is zeroed because any padding was
            # already applied by the PadAkg op above.
            attrs = self.attrs
            attrs['pad_list'] = [0, 0, 0, 0]
            attrs['dst_type'] = self.dst_type
            result = graph_builder.emit('Conv2D', [input_0, input_1], attrs=attrs)
        # unpad
        # Strip the alignment padding: the tail added to input 0's first dimension and,
        # on the output's last dimension, the tail added to input 1's first dimension.
        unpad_after = [input_0_pad_after[0], 0, 0, input_1_pad_after[0]]
        if unpad_after != [0, 0, 0, 0]:
            result = graph_builder.emit('UnPadAkg', [result], attrs={'tail': unpad_after})

        return result
|