mindspore 2.4.1__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +368 -242
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,330 @@
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+from mindspore.nn.cell import Cell
+from mindspore.ops.auto_generate.gen_ops_prim import BatchNormReduceGrad
+from mindspore.ops.auto_generate.gen_ops_prim import BatchNormElemtGrad
+from mindspore.communication import GlobalComm
+from mindspore.ops import ReduceOp
+from mindspore._c_expression import Tensor as Tensor_
+from mindspore.communication._comm_helper import _get_size_helper, HCCL_WORLD_COMM_GROUP
+from mindspore.ops._primitive_cache import _get_cache_prim
+from mindspore.ops import operations as P
+from mindspore import ops, mint
+
+
+DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+
+batch_norm_reduce_grad = BatchNormReduceGrad()
+batch_norm_elemt_grad = BatchNormElemtGrad()
+shape = P.Shape()
+
+
+def _deal_comm_outputs(output, async_op):
+    if isinstance(output, tuple):
+        if not async_op:
+            output[1].wait()
+            return output[0]
+        return output
+
+    if not async_op:
+        return output
+    return output
+
+
+def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
+    if not isinstance(group, str):
+        raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
+                        "but got 'group' type : {}.".format(type(group)))
+    return _get_size_helper(group=_get_group(group))
+
+
+def _contiguous(tensor):
+    if not tensor.is_contiguous() or tensor.storage_offset() != 0:
+        tensor = tensor.contiguous()
+    return tensor
+
+
+def _get_group(group):
+    """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
+    if group == DEFAULT_WORLD_COMM_GROUP:
+        return GlobalComm.WORLD_COMM_GROUP
+    return group
+
+
+def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+    if not isinstance(tensor, (Tensor, Tensor_)):
+        raise TypeError(
+            "For all_gather_into_tensor, the input tensor must be tensor")
+    group = _get_group(group)
+    tensor = _contiguous(tensor)
+    all_gather_op = _get_cache_prim(P.AllGather)(group=group)
+    output = all_gather_op(tensor)
+    return _deal_comm_outputs(output, async_op)
+
+
+def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+    if not isinstance(tensor, (Tensor, Tensor_)):
+        raise TypeError("For all_reduce, the input tensor must be tensor")
+    if not isinstance(op, str):
+        raise TypeError("For all_reduce, the input op type must be str")
+    if op not in ('sum', 'prod', 'min', 'max'):
+        raise TypeError(
+            "For all_reduce, the input op value must be one of sum, prod, min, max")
+    group = _get_group(group)
+    tensor = _contiguous(tensor)
+    all_reduce_op = _get_cache_prim(P.AllReduce)(op=op, group=group)
+    output = all_reduce_op(tensor)
+    return _deal_comm_outputs(output, async_op)
+
+
+def bprop_pynative(input_x, weight, bias, running_mean, running_var, eps, momentum,
+                   process_group, world_size, output, doutput):
+    _, mean_param, invstd_param, count_all_param = output
+    dout, _, _, _ = doutput
+
+    # KBK mode is not supported
+    if not dout.is_contiguous():
+        dout = dout.contiguous()
+
+    grad_input = grad_weight = grad_bias = None
+
+    inputG = True
+    weightG = True
+    biasG = True
+
+    # calculate local stats as well as grad_weight / grad_bias
+    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
+        dout,
+        input_x,
+        mean_param,
+        invstd_param,
+        weight,
+        inputG,
+        weightG,
+        biasG
+    )
+
+    if inputG:
+        # synchronizing stats used to calculate input gradient.
+        sum_dy_shape = shape(sum_dy)
+        num_channels = sum_dy_shape[0]
+        combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+
+        new_combined = all_reduce(combined, group=process_group)
+
+        sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+        # backward pass for gradient calculation
+        grad_input = batch_norm_elemt_grad(
+            dout,
+            input_x,
+            mean_param,
+            invstd_param,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_all_param
+        )
+
+    # synchronizing of grad_weight / grad_bias is not needed as distributed
+    # training would handle all reduce.
+    if weight is None or not weightG:
+        grad_weight = None
+
+    if weight is None or not biasG:
+        grad_bias = None
+
+    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def bprop_kbk(input_x, weight, bias, running_mean, running_var, eps, momentum,
+              process_group, world_size, output, doutput):
+    _, mean_param, invstd_param, count_all_param = output
+    dout, _, _, _ = doutput
+
+    dout = dout.contiguous()
+
+    grad_input = grad_weight = grad_bias = None
+
+    inputG = True
+    weightG = True
+    biasG = True
+
+    # calculate local stats as well as grad_weight / grad_bias
+    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
+        dout,
+        input_x,
+        mean_param,
+        invstd_param,
+        weight,
+        inputG,
+        weightG,
+        biasG
+    )
+
+    if inputG:
+        # synchronizing stats used to calculate input gradient.
+        sum_dy_shape = shape(sum_dy)
+        num_channels = sum_dy_shape[0]
+        combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+
+        new_combined = all_reduce(combined, group=process_group)
+
+        sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+        # backward pass for gradient calculation
+        grad_input = batch_norm_elemt_grad(
+            dout,
+            input_x,
+            mean_param,
+            invstd_param,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_all_param
+        )
+
+    # synchronizing of grad_weight / grad_bias is not needed as distributed
+    # training would handle all reduce.
+    if weight is None or not weightG:
+        grad_weight = None
+
+    if weight is None or not biasG:
+        grad_bias = None
+
+    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def construct_pynative(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                       world_size, self_num_features, self_world_size):
+    if self_world_size != world_size:
+        raise ValueError('World Size Error')
+    if not input.is_contiguous():
+        input = input.contiguous()
+    if weight is not None:
+        weight = weight.contiguous()
+
+    input_shape = shape(input)
+    input_numel = ops.numel(input)
+    size = int(input_numel // input_shape[1])
+    if size == 1 and world_size < 2:
+        raise ValueError(
+            'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+    # calculate mean/invstd for input.
+    mean, invstd = mint.batch_norm_stats(input, eps)
+    count = mint.full((1,), input_numel //
+                      input_shape[1], dtype=mean.dtype)
+
+    num_channels = input_shape[1]
+    if self_num_features != num_channels:
+        raise ValueError('Features Error')
+    # C, C, 1 -> (2C + 1)
+    combined = mint.cat([mean, invstd, count], dim=0)
+    # Use allgather instead of allreduce because count could be different across
+    # ranks, simple all reduce op can not give correct results.
+    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+    # all gathered mean, invstd and count.
+    # world_size * (2C + 1)
+    combined = all_gather_into_tensor(combined, process_group)
+    combined = ops.reshape(combined, [world_size, -1])
+    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+    mean_val_all, invstd_val_all, count_val_all = mint.split(
+        combined, num_channels, dim=1)
+    # calculate global mean & invstd
+    mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
+                                                            running_var, momentum, eps, count_val_all.view(-1))
+
+    # apply element-wise normalization
+    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+    return (out, mean, invstd, count_val_all.view(-1))
+
+
+def construct_kbk(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                  world_size, self_num_features, self_world_size):
+    if self_world_size != world_size:
+        raise ValueError('World Size Error')
+    input = input.contiguous()
+    if weight is not None:
+        weight = weight.contiguous()
+
+    input_shape = shape(input)
+    input_numel = ops.numel(input)
+    size = int(input_numel // input_shape[1])
+    if size == 1 and world_size < 2:
+        raise ValueError(
+            'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+    # calculate mean/invstd for input.
+    mean, invstd = mint.batch_norm_stats(input, eps)
+    count = mint.full((1,), input_numel //
+                      input_shape[1], dtype=mean.dtype)
+
+    num_channels = input_shape[1]
+    if self_num_features != num_channels:
+        raise ValueError('Features Error')
+    # C, C, 1 -> (2C + 1)
+    combined = mint.cat([mean, invstd, count], dim=0)
+    # Use allgather instead of allreduce because count could be different across
+    # ranks, simple all reduce op can not give correct results.
+    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+    # all gathered mean, invstd and count.
+    # world_size * (2C + 1)
+    combined = all_gather_into_tensor(combined, process_group)
+    combined = ops.reshape(combined, [world_size, -1])
+    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+    mean_all, invstd_all, count_all = mint.split(
+        combined, num_channels, dim=1)
+    # calculate global mean & invstd
+    mean, invstd = mint.batch_norm_gather_stats_with_counts(
+        input,
+        mean_all,
+        invstd_all,
+        running_mean,
+        running_var,
+        momentum,
+        eps,
+        count_all.view(-1)
+    )
+
+    # apply element-wise normalization
+    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+    return (out, mean, invstd, count_all.view(-1))
+
+
+class SyncBatchNormInner(Cell):
+    def __init__(self, self_num_features, self_world_size):
+        super(SyncBatchNormInner, self).__init__()
+        self.num_features = self_num_features
+        self.world_size = self_world_size
+        self.mode = context.get_context("mode")
+        if self.mode == 1:
+            self.fn_bprop = bprop_pynative
+            self.fn_construct = construct_pynative
+        else:
+            self.fn_bprop = bprop_kbk
+            self.fn_construct = construct_kbk
+
+    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        return self.fn_construct(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                                 world_size, self.num_features, self.world_size)
+
+    def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
+              process_group, world_size, output, doutput):
+        return self.fn_bprop(input_x, weight, bias, running_mean, running_var, eps, momentum,
+                             process_group, world_size, output, doutput)
+
+
+class _SyncBatchNorm(Cell):
+    def __init__(self, num_features, world_size, dtype=mindspore.float32):
+        super(_SyncBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.world_size = world_size
+        self.inner = SyncBatchNormInner(self.num_features, self.world_size)
+
+    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        res = self.inner(input, weight, bias, running_mean,
+                         running_var, eps, momentum, process_group, world_size)
+        output, _, _, _ = res
+        return output
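The +330-line hunk above (by its size and contents it appears to be the new mindspore/mint/nn/layer/_functions.py from the file list) wires BatchNormReduceGrad/BatchNormElemtGrad and the collective helpers into a private _SyncBatchNorm cell. A minimal usage sketch follows, assuming a distributed Ascend/HCCL job has already been launched (e.g. via msrun) and that the module is importable from that path; shapes and hyper-parameters are illustrative, not taken from the diff.

    # Hypothetical driver for the _SyncBatchNorm cell added in this hunk.
    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, Parameter
    from mindspore.communication import init, get_group_size, GlobalComm
    # Import path inferred from the file list above; treat as an assumption.
    from mindspore.mint.nn.layer._functions import _SyncBatchNorm

    init()                                  # requires a launched HCCL job
    world_size = get_group_size()

    num_features = 8
    net = _SyncBatchNorm(num_features, world_size)

    x = Tensor(np.random.randn(4, num_features, 16, 16), ms.float32)
    weight = Parameter(Tensor(np.ones(num_features), ms.float32))
    bias = Parameter(Tensor(np.zeros(num_features), ms.float32))
    running_mean = Parameter(Tensor(np.zeros(num_features), ms.float32))
    running_var = Parameter(Tensor(np.ones(num_features), ms.float32))

    # eps/momentum use the usual BatchNorm defaults; the group is the world group.
    out = net(x, weight, bias, running_mean, running_var,
              1e-5, 0.1, GlobalComm.WORLD_COMM_GROUP, world_size)
    print(out.shape)                        # (4, 8, 16, 16)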
@@ -84,7 +84,7 @@ class LogSigmoid(Cell):
     Logsigmoid is defined as:
 
     .. math::
-        \text{
+        \text{LogSigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),
 
     where :math:`x_{i}` is the element of the input.
 
@@ -127,7 +127,175 @@
         return mint.nn.functional.logsigmoid(input)
 
 
+class ELU(Cell):
+    r"""
+    Exponential Linear Unit activation function.
+
+    Applies the exponential linear unit function element-wise. The activation function is defined as:
+
+    .. math::
+        ELU_{i} =
+        \begin{cases}
+        x_i, &\text{if } x_i \geq 0; \cr
+        \alpha * (\exp(x_i) - 1), &\text{otherwise.}
+        \end{cases}
+
+    where :math:`x_i` represents the element of the input and :math:`\alpha` represents the `alpha` parameter.
+
+    ELU Activation Function Graph:
+
+    .. image:: ../images/ELU.png
+        :align: center
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        alpha (float, optional): The alpha value of ELU, the data type is float. Default: ``1.0`` .
+
+    Inputs:
+        - **input** (Tensor) - The input of ELU is a Tensor of any dimension.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input`.
+
+    Raises:
+        TypeError: If `alpha` is not a float.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
+        >>> elu = mint.nn.ELU()
+        >>> result = elu(input)
+        >>> print(result)
+        [-0.63212055 -0.86466473 0. 2. 1.]
+    """
+
+    def __init__(self, alpha=1.0):
+        """Initialize ELU."""
+        super(ELU, self).__init__()
+        self.alpha = alpha
+
+    def construct(self, input):
+        return mint.nn.functional.elu(input, self.alpha)
+
+
+class GLU(Cell):
+    r"""
+    Computes GLU (Gated Linear Unit activation function) of the input tensor.
+
+    .. math::
+        {GLU}(a, b) = a \otimes \sigma(b)
+
+    where :math:`a` is the first half of the `input` Tensor after `input` is split and :math:`b` is the second half.
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`\otimes` is the Hadamard product.
+    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_ .
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        dim (int, optional): The dimension to split the input `input`. The value range is `[-r, r)` where `r`
+            is the number of dimensions of `input`. Default: ``-1`` , the last dimension in `input`.
+
+    Inputs:
+        - **input** (Tensor) - Tensor to be calculated. Dtype is floating point and the shape
+          is :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional dimensions. :math:`N`
+          is required to be an even number, where :math:`N` is the size of `input` on the dimension
+          selected by `dim`.
+
+    Outputs:
+        Tensor, the same dtype as the `input`, with the shape :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`.
+
+    Raises:
+        TypeError: If `input` is not a Tensor or `dim` is not an int.
+        IndexError: If the value of `dim` is out of the range of `[-r, r)`, where `r` is the number
+            of dimensions of `input`.
+        RuntimeError: If dtype of `input` is not supported.
+        RuntimeError: If the length of `input` in the dimension selected by `dim` is not even.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> m = ms.mint.nn.GLU()
+        >>> input = ms.Tensor([[0.1,0.2,0.3,0.4],[0.5,0.6,0.7,0.8]])
+        >>> output = m(input)
+        >>> print(output)
+        [[0.05744425 0.11973753]
+         [0.33409387 0.41398472]]
+    """
+
+    def __init__(self, dim=-1):
+        """Initialize GLU."""
+        super().__init__("GLU")
+        self.dim = dim
+
+    def construct(self, input):
+        return mint.nn.functional.glu(input, self.dim)
+
+
+class Tanh(Cell):
+    r"""
+    Applies the Tanh function element-wise, returns a new tensor with the hyperbolic tangent of the elements of input.
+
+    Tanh function is defined as:
+
+    .. math::
+        tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},
+
+    where :math:`x_i` is an element of the input Tensor.
+
+    Tanh Activation Function Graph:
+
+    .. image:: ../images/Tanh.png
+        :align: center
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of any dimension, input with data type of float16 or float32.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input`.
+
+    Raises:
+        TypeError: If dtype of `input` is neither float16 nor float32.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([1, 2, 3, 2, 1]), mindspore.float16)
+        >>> tanh = mint.nn.Tanh()
+        >>> output = tanh(input)
+        >>> print(output)
+        [0.7617 0.964 0.995 0.964 0.7617]
+    """
+
+    def __init__(self):
+        """Initialize Tanh."""
+        super(Tanh, self).__init__()
+
+    def construct(self, input):
+        return mint.nn.functional.tanh(input)
+
+
 __all__ = [
     'LogSigmoid',
     'SiLU',
+    'ELU',
+    'GLU',
+    'Tanh',
 ]
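For the GLU cell added above, the docstring example can be checked by hand: the input is split into two halves along `dim`, and the first half is multiplied element-wise by the sigmoid of the second. A small NumPy sketch (illustrative only, not part of the diff) reproduces the printed values.

    # Verify the GLU docstring output with plain NumPy: a * sigmoid(b).
    import numpy as np

    x = np.array([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.6, 0.7, 0.8]], dtype=np.float32)
    a, b = np.split(x, 2, axis=-1)          # first half / second half along dim=-1
    glu = a * (1.0 / (1.0 + np.exp(-b)))    # Hadamard product with sigmoid gate
    print(glu)
    # [[0.05744425 0.11973753]
    #  [0.33409387 0.41398472]]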
@@ -0,0 +1,123 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""activation layer for mint"""
+from __future__ import absolute_import
+from __future__ import division
+
+from mindspore import mint
+from mindspore.nn.cell import Cell
+from mindspore import _checkparam as validator
+
+
+class Dropout2d(Cell):
+    r"""
+    During training, randomly zeroes some channels of the input tensor with probability `p`
+    from a Bernoulli distribution (For a 4-dimensional tensor with a shape of :math:`NCHW`,
+    the channel feature map refers to a 2-dimensional feature map with the shape of :math:`HW`).
+
+    For example, the :math:`j\_th` channel of the :math:`i\_th` sample in the batched input is a to-be-processed
+    `2D` tensor input[i,j].
+    Each channel will be zeroed out independently on every forward call with probability `p` using samples
+    from a Bernoulli distribution.
+
+    `Dropout2d` can improve the independence between channel feature maps.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Refer to :func:`mindspore.mint.nn.functional.dropout2d` for more details.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> dropout = mint.nn.Dropout2d(p=0.5)
+        >>> x = Tensor(np.ones([2, 1, 2, 3]), mindspore.float32)
+        >>> output = dropout(x)
+        >>> print(output.shape)
+        (2, 1, 2, 3)
+    """
+
+    def __init__(self, p=0.5):
+        """Initialize Dropout2d."""
+        super(Dropout2d, self).__init__()
+        validator.check_float_range(p, 0.0, 1.0, validator.INC_BOTH, "p", self.cls_name)
+        self.p = p
+
+    def construct(self, x):
+        if not self.training or self.p == 0:
+            return x
+
+        return mint.nn.functional.dropout2d(x, self.p)
+
+
+class Flatten(Cell):
+    r"""
+    Flatten the input Tensor along dimensions from `start_dim` to `end_dim`.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        start_dim (int, optional): The first dimension to flatten. Default: ``1`` .
+        end_dim (int, optional): The last dimension to flatten. Default: ``-1`` .
+
+    Inputs:
+        - **input** (Tensor) - The input Tensor to be flattened.
+
+    Outputs:
+        Tensor. If no dimensions are flattened, returns the original `input`, otherwise return the flattened Tensor.
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        TypeError: If `start_dim` or `end_dim` is not int.
+        ValueError: If `start_dim` is greater than `end_dim` after canonicalized.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[[1.2, 1.2], [2.1, 2.1]], [[2.2, 2.2], [3.2, 3.2]]]), mindspore.float32)
+        >>> net = mint.nn.Flatten()
+        >>> output = net(input)
+        >>> print(output)
+        [[1.2 1.2 2.1 2.1]
+         [2.2 2.2 3.2 3.2]]
+        >>> print(f"before flatten the x shape is {input.shape}")
+        before flatten the input shape is (2, 2, 2)
+        >>> print(f"after flatten the output shape is {output.shape}")
+        after flatten the output shape is (2, 4)
+    """
+
+    def __init__(self, start_dim=1, end_dim=-1):
+        """Initialize Flatten."""
+        super(Flatten, self).__init__()
+        self.start_dim = start_dim
+        self.end_dim = end_dim
+
+    def construct(self, input):
+        return mint.nn.functional.flatten(input, self.start_dim, self.end_dim)
+
+
+__all__ = [
+    'Dropout2d',
+    'Flatten',
+]
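The Flatten cell added above collapses dimensions `start_dim` through `end_dim` into one. The shape arithmetic can be sketched with NumPy (illustrative only, not part of the diff; a negative `end_dim` would first be canonicalized against the tensor rank).

    # Shape arithmetic behind mint.nn.Flatten(start_dim, end_dim), shown with NumPy.
    import numpy as np

    x = np.zeros((2, 3, 4, 5))
    start_dim, end_dim = 1, 2               # collapse dims 1 and 2 into one
    new_shape = x.shape[:start_dim] + (-1,) + x.shape[end_dim + 1:]
    print(np.reshape(x, new_shape).shape)   # (2, 12, 5)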