mindspore 2.4.1-cp310-cp310-win_amd64.whl → 2.5.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +99 -25
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -1
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +30 -27
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +10 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +292 -193
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +28 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +125 -65
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +3 -2
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +239 -64
- mindspore/experimental/llm_boost/atb/llama_boost.py +52 -30
- mindspore/experimental/llm_boost/atb/qwen_boost.py +47 -24
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +1 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +2 -0
- mindspore/nn/cell.py +142 -21
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/conv.py +3 -0
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +55 -23
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/__init__.py +22 -0
- mindspore/nn/utils/init.py +73 -0
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +107 -8
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8027 -3411
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1664 -294
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +23 -17
- mindspore/ops/operations/custom_ops.py +7 -49
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +153 -310
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +44 -34
- mindspore/parallel/_cell_wrapper.py +22 -3
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -2
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +18 -13
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +37 -15
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -14
- mindspore/train/serialization.py +461 -72
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/METADATA +3 -4
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/RECORD +368 -242
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.1.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
mindspore/mint/linalg/__init__.py
CHANGED

@@ -16,7 +16,13 @@
 from __future__ import absolute_import
 
 from mindspore.ops.function.math_func import inverse_ext as inv
+from mindspore.ops.function.math_func import vector_norm_ext as vector_norm
+from mindspore.ops.function.math_func import matrix_norm_ext as matrix_norm
+from mindspore.ops.function.math_func import linalg_norm as norm
 
 __all__ = [
     'inv',
+    'vector_norm',
+    'matrix_norm',
+    'norm',
 ]
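These three aliases give `mindspore.mint.linalg` a `torch.linalg`-style norm API. A minimal usage sketch (assuming a 2.5.0 install; the defaults shown, 2-norm for vectors and Frobenius norm for matrices, follow the conventions this API mirrors and are worth checking against the 2.5.0 docs):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

x = ms.Tensor(np.array([3.0, 4.0]), ms.float32)
print(mint.linalg.vector_norm(x))  # 2-norm by default -> 5.0
print(mint.linalg.norm(x))         # same default for a 1-D input -> 5.0

A = ms.Tensor(3.0 * np.eye(2, dtype=np.float32))
print(mint.linalg.matrix_norm(A))  # Frobenius norm by default -> sqrt(18) ~= 4.2426
```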
mindspore/mint/nn/__init__.py
CHANGED
@@ -22,6 +22,7 @@ import mindspore.ops as ops
 from mindspore.mint.nn import functional as F
 from mindspore.nn.cell import Cell
 from mindspore.nn import EmbeddingExt as Embedding, MaxPool2dExt as MaxPool2d, LayerNormExt as LayerNorm, Linear
+import mindspore.nn as nn
 
 # 1
 
@@ -32,6 +33,12 @@ from mindspore.nn.layer.basic import Identity
 # 4
 
 # 5
+from mindspore.mint.nn.layer.padding import (
+    ConstantPad1d, ConstantPad2d, ConstantPad3d,
+    ZeroPad1d, ZeroPad2d, ZeroPad3d,
+    ReflectionPad1d, ReflectionPad2d, ReflectionPad3d,
+    ReplicationPad1d, ReplicationPad2d, ReplicationPad3d
+)
 
 # 6
 from mindspore.nn.layer.basic import UnfoldExt as Unfold
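The twelve new padding Cells cover constant, zero, reflection, and replication padding in one to three dimensions. A short sketch (assuming 2.5.0; the tuple argument order `(left, right, top, bottom)` follows the torch-style convention these layers mirror):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

x = ms.Tensor(np.arange(9, dtype=np.float32).reshape(1, 1, 3, 3))

# Pad one column of zeros on the left and right of the last axis.
pad = mint.nn.ZeroPad2d((1, 1, 0, 0))
print(pad(x).shape)   # (1, 1, 3, 5)

# A single int pads all four sides; reflection mirrors border values.
refl = mint.nn.ReflectionPad2d(1)
print(refl(x).shape)  # (1, 1, 5, 5)
```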
@@ -53,7 +60,8 @@ from mindspore.nn.layer import ReLU
 # 14
 from mindspore.nn.layer.basic import DropoutExt as Dropout
 # 15
-
+from mindspore.mint.nn.layer.conv import Conv2d, ConvTranspose2d
+from mindspore.mint.nn.layer.conv import Conv3d
 # 16
 from mindspore.nn.layer import LogSoftmaxExt as LogSoftmax
 # 17
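`Conv2d`, `Conv3d`, and `ConvTranspose2d` fill slots that were previously placeholders. A minimal shape-check sketch (assuming 2.5.0; the keyword names follow the torch-style constructors these layers mirror):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

x = ms.Tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))

# padding=1 with a 3x3 kernel preserves the 32x32 spatial size.
conv = mint.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1)
print(conv(x).shape)      # (1, 8, 32, 32)

# A stride-2 transposed conv doubles H and W: (32 - 1) * 2 + 2 = 64.
up = mint.nn.ConvTranspose2d(in_channels=8, out_channels=3, kernel_size=2, stride=2)
print(up(conv(x)).shape)  # (1, 3, 64, 64)
```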
@@ -105,6 +113,7 @@ from mindspore.nn.layer import PReLUExt as PReLU
 # 40
 from mindspore.mint.nn.layer.normalization import GroupNorm
 from mindspore.mint.nn.layer.normalization import LayerNorm
+from mindspore.mint.nn.layer.normalization import SyncBatchNorm
 # 41
 
 # 42
@@ -221,11 +230,13 @@ from mindspore.mint.nn.layer.activation import SiLU, LogSigmoid
 # 97
 
 # 98
-
+from mindspore.nn.layer import AvgPool1dExt as AvgPool1d
 # 99
 from mindspore.nn.layer import AvgPool2dExt as AvgPool2d
 # 100
 from mindspore.nn.layer import SoftShrink as Softshrink
+# 152
+from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool3d
 # 159
 
 # 220
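On the pooling side, `AvgPool1d` and `AdaptiveAvgPool3d` round out the 1D/2D/3D family. A quick sketch (assuming 2.5.0; argument names follow the torch-style API these layers mirror):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

x = ms.Tensor(np.arange(8, dtype=np.float32).reshape(1, 1, 8))
pool = mint.nn.AvgPool1d(kernel_size=2, stride=2)
print(pool(x))       # averages non-overlapping pairs: [[[0.5 2.5 4.5 6.5]]]

v = ms.Tensor(np.random.randn(1, 2, 8, 8, 8).astype(np.float32))
gap = mint.nn.AdaptiveAvgPool3d(1)  # global average pool to a 1x1x1 volume
print(gap(v).shape)  # (1, 2, 1, 1, 1)
```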
@@ -237,11 +248,28 @@ from mindspore.nn.layer import HSwish as Hardswish
 # 238
 from mindspore.nn.loss import L1LossExt as L1Loss
 
+# 254
+from mindspore.mint.nn.layer.pooling import MaxUnpool2d
+
 # 257
 
 # 258
 from mindspore.ops.function.nn_func import mse_loss_ext
 
+# 393
+from mindspore.mint.nn.layer.basic import Dropout2d
+
+# 406
+from mindspore.mint.nn.layer.activation import ELU
+
+# 407
+from mindspore.mint.nn.layer.basic import Flatten
+
+# 421
+from mindspore.mint.nn.layer.activation import Tanh
+
+# 536
+from mindspore.mint.nn.layer.activation import GLU
 
 # 674
 from mindspore.mint.nn.layer.normalization import BatchNorm1d
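Among the newly wired activations, `GLU` is the one with a non-obvious contract: it halves the chosen dimension. A tiny sketch (assuming 2.5.0 and the torch-style default of splitting the last dimension, which is worth verifying against your build):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

# GLU splits the last axis into (a, b) and returns a * sigmoid(b).
x = ms.Tensor(np.array([[1.0, 2.0, 0.0, 0.0]]), ms.float32)
glu = mint.nn.GLU()
print(glu(x))  # [[0.5 1. ]] since sigmoid(0) = 0.5
```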
@@ -256,6 +284,209 @@ from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool1d
 
 from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool2d
 
+from mindspore.ops.function.nn_func import cross_entropy_ext as cross_entropy
+
+from mindspore.ops.function.nn_func import _nllloss_nd as nllloss
+
+
+class NLLLoss(Cell):
+    r"""
+    Gets the negative log likelihood loss between inputs and target.
+
+    The nll loss with reduction=none can be described as:
+
+    .. math::
+
+        \ell(x, t)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top},
+        \quad l_{n}=-w_{t_{n}} x_{n, t_{n}},
+        \quad w_{c}=\text { weight }[c] \cdot \mathbb{1}
+        \{c \not= \text{ignore_index}\},
+
+    where :math:`x` is the inputs, :math:`t` is the target, :math:`w` is the weight,
+    :math:`N` is the batch size, :math:`c` belonging to :math:`[0, C-1]` is class index,
+    where :math:`C` is the number of classes.
+
+    If `reduction` is not ``None`` (default ``'mean'``), then
+
+    .. math::
+
+        \ell(x, t)=\left\{\begin{array}{ll}
+        \sum_{n=1}^{N} \frac{1}{\sum_{n=1}^{N} w_{t n}} l_{n}, & \text { if reduction }=\text { 'mean', } \\
+        \sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
+        \end{array}\right.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
+            If not None, the shape is :math:`(C,)`, data type must be float16 or float32 or bfloat16(only supported by
+            Atlas A2 training series products). It should have the same data type as `input` . Default: ``None`` .
+        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
+            gradient. Only valid in class indices, please set it to a negative number in probabilities.
+            Default: ``-100`` .
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the weighted mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.
+
+    Inputs:
+        - **input** (Tensor) - :math:`(N)` or :math:`(N, C)` where `C = number of classes` , `N = batch size` ,
+          or :math:`(N, C, d_1, d_2, ..., d_K)` (for high-dimensional data).
+          `input` is expected to be log-probabilities, data type must be float16 or float32 or bfloat16(only supported
+          by Atlas A2 training series products).
+        - **target** (Tensor) - :math:`()` or :math:`(N)` ,
+          where the value range is :math:`[0, C-1]`, or :math:`(N, d_1, d_2, ..., d_K)` for
+          high-dimensional loss, data type must be int32 or int64 or uint8.
+
+    Outputs:
+        Tensor, the data type is the same as `input` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, ops
+        >>> inputs = mindspore.Tensor(np.random.randn(3, 5), mindspore.float32)
+        >>> target = mindspore.Tensor(np.array([1, 0, 4]), mindspore.int32)
+        >>> op = mindspore.mint.nn.NLLLoss()
+        >>> output = op(inputs, target)
+
+    """
+
+    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
+        super(NLLLoss, self).__init__()
+        self.weight = weight
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+
+    def construct(self, input, target):
+        out = nllloss(input, target, self.weight, self.ignore_index, self.reduction)
+        return out
+
+
+class CrossEntropyLoss(Cell):
+    r"""
+    The cross entropy loss between input and target.
+
+    The cross entropy supports two kind of targets:
+
+    - Class indices (int) in the range :math:`[0, C)` where :math:`C` is the number of classes,
+      the loss with reduction=none can be described as:
+
+      .. math::
+
+          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+          l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
+          \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}
+
+      where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight, :math:`N` is the batch size,
+      :math:`c` belonging to :math:`[0, C-1]` is class index, where :math:`C` is the number of classes.
+
+      If `reduction` is not ``None`` (default ``'mean'`` ), then
+
+      .. math::
+
+          \ell(x, y) = \begin{cases}
+          \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore_index}\}} l_n, &
+          \text{if reduction} = \text{'mean',}\\
+          \sum_{n=1}^N l_n, &
+          \text{if reduction} = \text{'sum'.}
+          \end{cases}
+
+    - Probabilities (float) for each class, useful when labels beyond a single class per minibatch item
+      are required, the loss with reduction=none can be described as:
+
+      .. math::
+
+          \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+          l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
+
+      where :math:`x` is the inputs, :math:`y` is the target, :math:`w` is the weight, N is the batch size,
+      :math:`c` belonging to :math:`[0, C-1]` is class index, where :math:`C` is the number of classes.
+
+      If `reduction` is not ``None`` (default ``'mean'`` ), then
+
+      .. math::
+
+          \ell(x, y) = \begin{cases}
+          \frac{\sum_{n=1}^N l_n}{N}, &
+          \text{if reduction} = \text{'mean',}\\
+          \sum_{n=1}^N l_n, &
+          \text{if reduction} = \text{'sum'.}
+          \end{cases}
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Note:
+        Dynamic shape, dynamic rank and variable constant input are not supported in `strict graph mode
+        (jit_syntax_level=mindspore.STRICT)
+        <https://www.mindspore.cn/docs/en/master/model_train/program_form/static_graph.html>`_.
+
+    Args:
+        weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
+            If not None, the shape is :math:`(C,)`, data type must be float16 or float32 or bfloat16(only supported by
+            Atlas A2 training series products). Default: ``None`` .
+        ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input
+            gradient. Only valid in class indices, please set it to a negative number in probabilities.
+            Default: ``-100`` .
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the weighted mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.
+
+        label_smoothing (float, optional): Label smoothing values, a regularization tool used to prevent the model
+            from overfitting when calculating Loss. The value range is [0.0, 1.0]. Default: ``0.0`` .
+
+    Inputs:
+        - **input** (Tensor) - :math:`(N)` or :math:`(N, C)` where `C = number of classes` or :math:`(N, C, H, W)`
+          in case of 2D Loss, or :math:`(N, C, d_1, d_2, ..., d_K)`.
+          `input` is expected to be log-probabilities, data type must be float16 or float32 or bfloat16(only supported
+          by Atlas A2 training series products).
+        - **target** (Tensor) - For class indices, tensor of shape :math:`()`, :math:`(N)` or
+          :math:`(N, d_1, d_2, ..., d_K)` , data type must be int32 or int64. For probabilities, tensor of shape
+          :math:`(N,)` , :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` , data type must be float16 or float32
+          or bfloat16(only supported by Atlas A2 training series products).
+
+    Outputs:
+        Tensor, the data type is the same as `input` .
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> import numpy as np
+        >>> # Case 1: Indices labels
+        >>> inputs = ms.Tensor(np.random.randn(3, 5), ms.float32)
+        >>> target = ms.Tensor(np.array([1, 0, 4]), ms.int32)
+        >>> op = ms.mint.nn.CrossEntropyLoss()
+        >>> output = op(inputs, target)
+        >>> # Case 2: Probability labels
+        >>> inputs = ms.Tensor(np.random.randn(3, 5), ms.float32)
+        >>> target = ms.Tensor(np.random.randn(3, 5), ms.float32)
+        >>> op = ms.mint.nn.CrossEntropyLoss()
+        >>> output = op(inputs, target)
+    """
+
+    def __init__(self, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0):
+        super(CrossEntropyLoss, self).__init__()
+        self.weight = weight
+        self.ignore_index = ignore_index
+        self.reduction = reduction
+        self.label_smoothing = label_smoothing
+
+    def construct(self, input, target):
+        out = cross_entropy(input, target, self.weight, self.ignore_index, self.reduction, self.label_smoothing)
+        return out
+
 
 class BCEWithLogitsLoss(Cell):
     r"""
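Both new Cells are thin wrappers over the functional forms (`cross_entropy_ext`, `_nllloss_nd`). For class-index targets, cross entropy is exactly log-softmax followed by NLL; a numpy-only sketch of that identity, independent of the MindSpore API:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 5)).astype(np.float32)  # (N=3, C=5)
target = np.array([1, 0, 4])

# Numerically stable log-softmax over the class axis.
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# NLL with reduction='mean': negative log-probability of each target class.
loss = -log_probs[np.arange(len(target)), target].mean()
print(loss)  # the value cross entropy with index targets would produce
```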
@@ -396,12 +627,6 @@ class GELU(Cell):
         >>> print(output)
         [[-1.5880802e-01 3.9999299e+00 -3.1077917e-21]
          [ 1.9545976e+00 -2.2918017e-07 9.0000000e+00]]
-        >>> gelu = mint.nn.GELU(approximate=False)
-        >>> # CPU not support "approximate=False", using "approximate=True" instead
-        >>> output = gelu(input)
-        >>> print(output)
-        [[-1.5865526e-01 3.9998732e+00 -0.0000000e+00]
-         [ 1.9544997e+00 -1.4901161e-06 9.0000000e+00]]
     """
 
     def __init__(self):
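This hunk drops the `approximate=False` branch from the GELU doctest. For readers wondering why the two printouts differed only in low-order digits: the two GELU variants, exact (erf-based) and tanh-approximate, are numerically very close. A standalone sketch of both formulas (plain Python, no MindSpore required):

```python
import math

def gelu_exact(v):
    # x * Phi(x), with Phi the exact Gaussian CDF
    return v * 0.5 * (1.0 + math.erf(v / math.sqrt(2.0)))

def gelu_tanh(v):
    # tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
    return 0.5 * v * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v ** 3)))

for v in (-1.0, 0.5, 2.0):
    print(f"{v:+.1f}  exact={gelu_exact(v):.7f}  tanh={gelu_tanh(v):.7f}")
# the two columns agree to roughly four decimal places
```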
@@ -412,6 +637,82 @@ class GELU(Cell):
         return F.gelu(input)
 
 
+class Hardtanh(Cell):
+    r"""
+    Activation function Hardtanh.
+
+    Refer to :func:`mindspore.mint.nn.functional.hardtanh` for more details.
+
+    Hardtanh Activation Function Graph:
+
+    .. image:: ../images/Hardtanh.png
+        :align: center
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
+        >>> hardtanh = mint.nn.Hardtanh(min_val=-1.0, max_val=1.0)
+        >>> output = hardtanh(input)
+        >>> print(output)
+        [-1. -1. 0. 1. 1.]
+    """
+
+    def __init__(self, min_val=-1.0, max_val=1.0, inplace=False):
+        """Initialize Hardtanh"""
+        super(Hardtanh, self).__init__()
+        self.min_val = min_val
+        self.max_val = max_val
+        self.inplace = inplace
+
+    def construct(self, input):
+        if self.inplace:
+            return F.hardtanh_(input, self.min_val, self.max_val)
+        return F.hardtanh_op(input, self.min_val, self.max_val)
+
+
+class ReLU6(Cell):
+    r"""
+    Activation function ReLU6.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Refer to :func:`mindspore.mint.nn.functional.relu6` for more details.
+
+    ReLU6 Activation Function Graph:
+
+    .. image:: ../images/ReLU6.png
+        :align: center
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> relu6 = mint.nn.ReLU6()
+        >>> output = relu6(input)
+        >>> print(output)
+        [[0. 4. 0.]
+         [2. 0. 6.]]
+    """
+
+    def __init__(self, inplace=False):
+        """Initialize ReLU6"""
+        super(ReLU6, self).__init__()
+        self.inplace = inplace
+
+    def construct(self, input):
+        return F.relu6(input, self.inplace)
+
+
 class Mish(Cell):
     r"""
     Computes MISH (A Self Regularized Non-Monotonic Neural Activation Function)
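Note the relationship between the two new activations: ReLU6 is Hardtanh with the bounds fixed at [0, 6]. A quick equivalence check (assuming 2.5.0 on an Ascend target, per the docstrings above):

```python
import numpy as np
import mindspore as ms
from mindspore import mint

x = ms.Tensor(np.array([-2.0, 3.0, 7.0]), ms.float32)

relu6 = mint.nn.ReLU6()
clamp06 = mint.nn.Hardtanh(min_val=0.0, max_val=6.0)

print(relu6(x))    # [0. 3. 6.]
print(clamp06(x))  # identical: min(max(x, 0), 6)
```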
@@ -523,6 +824,132 @@ class MSELoss(Cell):
         return out
 
 
+class SmoothL1Loss(Cell):
+    r"""
+    Computes smooth L1 loss, a robust L1 loss.
+
+    Refer to :func:`mindspore.mint.nn.functional.smooth_l1_loss` for more details.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, mint
+        >>> input = Tensor(np.array([2, 2, 3]), mindspore.float32)
+        >>> target = Tensor(np.array([2, 2, 2]), mindspore.float32)
+        >>> beta = 1.0
+        >>> reduction_1 = 'none'
+        >>> loss1 = mint.nn.SmoothL1Loss(reduction=reduction_1, beta=beta)
+        >>> output = loss1(input, target)
+        >>> print(output)
+        [0. 0. 0.5]
+        >>> reduction_2 = 'mean'
+        >>> loss2 = mint.nn.SmoothL1Loss(reduction=reduction_2, beta=beta)
+        >>> output = loss2(input, target)
+        >>> print(output)
+        0.16666667
+        >>> reduction_3 = 'sum'
+        >>> loss3 = mint.nn.SmoothL1Loss(reduction=reduction_3, beta=beta)
+        >>> output = loss3(input, target)
+        >>> print(output)
+        0.5
+    """
+
+    def __init__(self, reduction='mean', beta=1.0):
+        super(SmoothL1Loss, self).__init__()
+        self.smooth_l1_loss = ops.function.smooth_l1_loss
+        self.reduction = reduction
+        self.beta = beta
+
+    def construct(self, input, target):
+        out = self.smooth_l1_loss(input, target, self.beta, self.reduction)
+        return out
+
+
+class BCELoss(Cell):
+    r"""
+    Compute the binary cross entropy between the true labels and predicted labels.
+
+    Set the predicted labels as :math:`x`, true labels as :math:`y`, the output loss as :math:`\ell(x, y)`.
+    The formula is as follow:
+
+    .. math::
+        L = \{l_1,\dots,l_n,\dots,l_N\}^\top, \quad
+        l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right]
+
+    where N is the batch size. Then,
+
+    .. math::
+        \ell(x, y) = \begin{cases}
+        L, & \text{if reduction} = \text{'none';}\\
+        \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\
+        \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.}
+        \end{cases}
+
+    .. note::
+        Note that the predicted labels should always be the output of sigmoid. Because it is a two-class
+        classification, the true labels should be numbers between 0 and 1.
+        And if :math:`x_n` is either 0 or 1, one of the log terms would be mathematically undefined in the above loss
+        equation.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
+            And it must have the same shape and data type as `inputs`. Default: ``None`` .
+        reduction (str, optional): Apply specific reduction method to the output: ``'none'`` , ``'mean'`` ,
+            ``'sum'`` . Default: ``'mean'`` .
+
+            - ``'none'``: no reduction will be applied.
+            - ``'mean'``: compute and return the weighted mean of elements in the output.
+            - ``'sum'``: the output elements will be summed.
+
+    Inputs:
+        - **input** (Tensor) - The input tensor with shape :math:`(N, *)` where :math:`*` means, any number
+          of additional dimensions. The data type must be float16 or float32 or bfloat16(only supported
+          by Atlas A2 training series products).
+        - **target** (Tensor) - The label tensor with shape :math:`(N, *)` where :math:`*` means, any number
+          of additional dimensions. The same shape and data type as `input`.
+
+    Outputs:
+        Tensor, has the same dtype as `input`. if `reduction` is ``'none'``, then it has the same shape as `input`.
+        Otherwise, it is a scalar Tensor.
+
+    Raises:
+        TypeError: If dtype of `input`, `target` or `weight` (if given) is not float16, float32 or bfloat16.
+        ValueError: If `reduction` is not one of ``'none'``, ``'mean'``, ``'sum'``.
+        ValueError: If shape of `input` is not the same as `target` or `weight` (if given).
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import nn
+        >>> import numpy as np
+        >>> weight = ms.Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 3.3, 2.2]]), ms.float32)
+        >>> loss = nn.BCELoss(weight=weight, reduction='mean')
+        >>> input = ms.Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]), ms.float32)
+        >>> target = ms.Tensor(np.array([[0, 1, 0], [0, 0, 1]]), ms.float32)
+        >>> output = loss(input, target)
+        >>> print(output)
+        1.8952923
+    """
+
+    def __init__(self, weight=None, reduction='mean'):
+        super(BCELoss, self).__init__()
+        self.bce_loss = nn.loss.BCELoss(weight, reduction)
+
+    def construct(self, input, target):
+        return self.bce_loss(input, target)
+
+
 __all__ = [
     # 1
     'BCEWithLogitsLoss',
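The SmoothL1Loss doctest values are easy to verify by hand: with `beta=1.0`, the element-wise loss is quadratic for `|input - target| < beta` and linear past it. A numpy-only check of the outputs quoted above:

```python
import numpy as np

def smooth_l1(x, y, beta=1.0):
    d = np.abs(x - y)
    # quadratic inside the beta zone, linear outside
    return np.where(d < beta, 0.5 * d * d / beta, d - 0.5 * beta)

loss = smooth_l1(np.array([2.0, 2.0, 3.0]), np.array([2.0, 2.0, 2.0]))
print(loss)                     # [0.  0.  0.5]
print(loss.mean(), loss.sum())  # 0.16666667 0.5 -- matches 'mean' and 'sum'
```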
@@ -533,6 +960,18 @@ __all__ = [
     # 4
 
     # 5
+    'ConstantPad1d',
+    'ConstantPad2d',
+    'ConstantPad3d',
+    'ZeroPad1d',
+    'ZeroPad2d',
+    'ZeroPad3d',
+    'ReflectionPad1d',
+    'ReflectionPad2d',
+    'ReflectionPad3d',
+    'ReplicationPad1d',
+    'ReplicationPad2d',
+    'ReplicationPad3d',
 
     # 6
     'Fold',
@@ -554,15 +993,15 @@ __all__ = [
     # 14
 
     # 15
-
+    'Conv2d',
     # 16
     'LogSoftmax',
     # 17
-
+    'ConvTranspose2d',
     # 18
     'PReLU',
     # 19
-
+    'Conv3d',
     # 20
 
     # 21
@@ -654,7 +1093,7 @@ __all__ = [
     # 63
 
     # 64
-
+    'SyncBatchNorm',
     # 65
 
     # 66
@@ -724,11 +1163,13 @@ __all__ = [
     'AdaptiveAvgPool2d',
 
     # 98
-
+    'AvgPool1d',
     # 99
     'AvgPool2d',
     # 100
     'SELU',
+    # 152
+    'AdaptiveAvgPool3d',
     # 159
     'GELU',
     # 220
@@ -739,15 +1180,33 @@ __all__ = [
     'Hardswish',
     # 238
     'L1Loss',
+    # 254
+    'MaxUnpool2d',
     # 267
     'Mish',
     # 258
     'MSELoss',
     # 259
 
+    # 294
+    'SmoothL1Loss',
+
+    # 393
+    'Dropout2d',
+    # 406
+    'ELU',
+    # 407
+    'Flatten',
+    # 412
+    'Hardtanh',
+    'ReLU6',
+    # 413
+    'BCELoss',
+    # 421
+    'Tanh',
+
     # 556
     'LogSigmoid',
-
     # 674
     'BatchNorm1d',
     # 675