mindspore-2.4.10-cp310-cp310-win_amd64.whl → mindspore-2.6.0-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release — this version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
mindspore/mint/nn/layer/__init__.py

@@ -27,13 +27,30 @@ from mindspore.mint.nn.layer.normalization import BatchNorm1d
 from mindspore.mint.nn.layer.normalization import BatchNorm2d
 from mindspore.mint.nn.layer.normalization import BatchNorm3d
 from mindspore.mint.nn.layer.normalization import LayerNorm
+from mindspore.mint.nn.layer.normalization import SyncBatchNorm
 from mindspore.mint.nn.layer.activation import LogSigmoid
 from mindspore.mint.nn.layer.activation import SiLU
+from mindspore.mint.nn.layer.activation import Threshold
+from mindspore.mint.nn.layer.basic import Dropout2d
+from mindspore.mint.nn.layer.pooling import AdaptiveMaxPool1d
 from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool1d
 from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool2d
+from mindspore.mint.nn.layer.pooling import AdaptiveAvgPool3d


-__all__ = [
-
-
-
+__all__ = [
+    'GroupNorm',
+    'BatchNorm1d',
+    'BatchNorm2d',
+    'BatchNorm3d',
+    'LayerNorm',
+    'LogSigmoid',
+    'SiLU',
+    'Dropout2d',
+    'AdaptiveMaxPool1d',
+    'AdaptiveAvgPool1d',
+    'AdaptiveAvgPool2d',
+    'AdaptiveAvgPool3d',
+    'SyncBatchNorm',
+    'Threshold',
+]
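The net effect of this hunk is a wider `mindspore.mint.nn` export surface. A minimal usage sketch of two of the newly exported layers — assuming a 2.6.0 install; `Threshold(threshold, value)` follows the docstring shown further below, while the `AdaptiveAvgPool3d(output_size)` constructor convention is an assumption, not confirmed by this diff:

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.mint.nn import AdaptiveAvgPool3d, Threshold  # exported as of 2.6.0

x = Tensor(np.random.randn(1, 2, 4, 4, 4), mindspore.float32)
pool = AdaptiveAvgPool3d(1)    # assumed output_size convention: pool to (1, 1, 1)
thresh = Threshold(0.0, 0.0)   # keep x where x > 0, else replace with 0
y = thresh(pool(x))
print(y.shape)                 # (1, 2, 1, 1, 1)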
mindspore/mint/nn/layer/_functions.py (new file)

@@ -0,0 +1,334 @@
+import mindspore
+from mindspore import Tensor
+from mindspore import context
+import mindspore.communication
+import mindspore.communication.comm_func
+from mindspore.nn.cell import Cell
+from mindspore.ops.auto_generate.gen_ops_prim import BatchNormReduceGrad
+from mindspore.ops.auto_generate.gen_ops_prim import BatchNormElemtGrad
+from mindspore.communication import GlobalComm
+from mindspore.ops import ReduceOp
+from mindspore._c_expression import TensorPy as Tensor_
+from mindspore.communication._comm_helper import _get_size_helper, HCCL_WORLD_COMM_GROUP
+from mindspore.ops._primitive_cache import _get_cache_prim
+from mindspore.communication.comm_func import all_gather_into_tensor as all_gather_into_tensor_dy
+from mindspore.ops import operations as P
+from mindspore import ops, mint
+
+
+DEFAULT_WORLD_COMM_GROUP = HCCL_WORLD_COMM_GROUP
+
+batch_norm_reduce_grad = BatchNormReduceGrad()
+batch_norm_elemt_grad = BatchNormElemtGrad()
+shape = P.Shape()
+
+
+def _deal_comm_outputs(output, async_op):
+    if isinstance(output, tuple):
+        if not async_op:
+            output[1].wait()
+            return output[0]
+        return output
+
+    if not async_op:
+        return output
+    return output
+
+
+def get_group_size(group=GlobalComm.WORLD_COMM_GROUP):
+    if not isinstance(group, str):
+        raise TypeError("For 'get_group_size', the argument 'group' must be type of string, "
+                        "but got 'group' type : {}.".format(type(group)))
+    return _get_size_helper(group=_get_group(group))
+
+
+def _contiguous(tensor):
+    if not tensor.is_contiguous() or tensor.storage_offset() != 0:
+        tensor = tensor.contiguous()
+    return tensor
+
+
+def _get_group(group):
+    """Return the world communication group if the `group` is `DEFAULT_WORLD_COMM_GROUP`."""
+    if group == DEFAULT_WORLD_COMM_GROUP:
+        return GlobalComm.WORLD_COMM_GROUP
+    return group
+
+
+def all_gather_into_tensor(tensor, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+    if not isinstance(tensor, (Tensor, Tensor_)):
+        raise TypeError(
+            "For all_gather_into_tensor, the input tensor must be tensor")
+    group = _get_group(group)
+    tensor = _contiguous(tensor)
+    all_gather_op = _get_cache_prim(P.AllGather)(group=group)
+    output = all_gather_op(tensor)
+    return _deal_comm_outputs(output, async_op)
+
+
+def all_reduce(tensor, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP, async_op=False):
+    if not isinstance(tensor, (Tensor, Tensor_)):
+        raise TypeError("For all_reduce, the input tensor must be tensor")
+    if not isinstance(op, str):
+        raise TypeError("For all_reduce, the input op type must be str")
+    if op not in ('sum', 'prod', 'min', 'max'):
+        raise TypeError(
+            "For all_reduce, the input op value must be one of sum, prod, min, max")
+    group = _get_group(group)
+    tensor = _contiguous(tensor)
+    all_reduce_op = _get_cache_prim(P.AllReduce)(op=op, group=group)
+    output = all_reduce_op(tensor)
+    return _deal_comm_outputs(output, async_op)
+
+
+def bprop_pynative(input_x, weight, bias, running_mean, running_var, eps, momentum,
+                   process_group, world_size, output, doutput):
+    _, mean_param, invstd_param, count_all_param = output
+    dout, _, _, _ = doutput
+
+    # KBK mode is not supported
+    if not dout.is_contiguous():
+        dout = dout.contiguous()
+
+    grad_input = grad_weight = grad_bias = None
+
+    inputG = True
+    weightG = True
+    biasG = True
+
+    # calculate local stats as well as grad_weight / grad_bias
+    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
+        dout,
+        input_x,
+        mean_param,
+        invstd_param,
+        weight,
+        inputG,
+        weightG,
+        biasG
+    )
+
+    if inputG:
+        # synchronizing stats used to calculate input gradient.
+        sum_dy_shape = shape(sum_dy)
+        num_channels = sum_dy_shape[0]
+        combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+
+        new_combined, _ = mindspore.communication.comm_func.all_reduce(
+            combined, group=process_group)
+
+        sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+        # backward pass for gradient calculation
+        grad_input = batch_norm_elemt_grad(
+            dout,
+            input_x,
+            mean_param,
+            invstd_param,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_all_param
+        )
+
+    # synchronizing of grad_weight / grad_bias is not needed as distributed
+    # training would handle all reduce.
+    if weight is None or not weightG:
+        grad_weight = None
+
+    if weight is None or not biasG:
+        grad_bias = None
+
+    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def bprop_kbk(input_x, weight, bias, running_mean, running_var, eps, momentum,
+              process_group, world_size, output, doutput):
+    _, mean_param, invstd_param, count_all_param = output
+    dout, _, _, _ = doutput
+
+    dout = dout.contiguous()
+
+    grad_input = grad_weight = grad_bias = None
+
+    inputG = True
+    weightG = True
+    biasG = True
+
+    # calculate local stats as well as grad_weight / grad_bias
+    sum_dy, sum_dy_xmu, grad_weight, grad_bias = batch_norm_reduce_grad(
+        dout,
+        input_x,
+        mean_param,
+        invstd_param,
+        weight,
+        inputG,
+        weightG,
+        biasG
+    )
+
+    if inputG:
+        # synchronizing stats used to calculate input gradient.
+        sum_dy_shape = shape(sum_dy)
+        num_channels = sum_dy_shape[0]
+        combined = mint.cat([sum_dy, sum_dy_xmu], dim=0)
+
+        new_combined = all_reduce(combined, group=process_group)
+
+        sum_dy, sum_dy_xmu = mint.split(new_combined, num_channels)
+
+        # backward pass for gradient calculation
+        grad_input = batch_norm_elemt_grad(
+            dout,
+            input_x,
+            mean_param,
+            invstd_param,
+            weight,
+            sum_dy,
+            sum_dy_xmu,
+            count_all_param
+        )
+
+    # synchronizing of grad_weight / grad_bias is not needed as distributed
+    # training would handle all reduce.
+    if weight is None or not weightG:
+        grad_weight = None
+
+    if weight is None or not biasG:
+        grad_bias = None
+
+    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
+
+
+def construct_pynative(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                       world_size, self_num_features, self_world_size):
+    if self_world_size != world_size:
+        raise ValueError('World Size Error')
+    if not input.is_contiguous():
+        input = input.contiguous()
+    if weight is not None:
+        weight = weight.contiguous()
+
+    input_shape = shape(input)
+    input_numel = ops.numel(input)
+    size = int(input_numel // input_shape[1])
+    if size == 1 and world_size < 2:
+        raise ValueError(
+            'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+    # calculate mean/invstd for input.
+    mean, invstd = mint.batch_norm_stats(input, eps)
+    count = mint.full((1,), input_numel //
+                      input_shape[1], dtype=mean.dtype)
+
+    num_channels = input_shape[1]
+    if self_num_features != num_channels:
+        raise ValueError('Features Error')
+    # C, C, 1 -> (2C + 1)
+    combined = mint.cat([mean, invstd, count], dim=0)
+    # Use allgather instead of allreduce because count could be different across
+    # ranks, simple all reduce op can not give correct results.
+    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+    # all gathered mean, invstd and count.
+    # world_size * (2C + 1)
+    combined, _ = all_gather_into_tensor_dy(combined, process_group)
+    combined = ops.reshape(combined, [world_size, -1])
+    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+    mean_val_all, invstd_val_all, count_val_all = mint.split(
+        combined, num_channels, dim=1)
+    # calculate global mean & invstd
+    mean, invstd = mint.batch_norm_gather_stats_with_counts(input, mean_val_all, invstd_val_all, running_mean,
+                                                            running_var, momentum, eps, count_val_all.view(-1))
+
+    # apply element-wise normalization
+    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+    return (out, mean, invstd, count_val_all.view(-1))
+
+
+def construct_kbk(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                  world_size, self_num_features, self_world_size):
+    if self_world_size != world_size:
+        raise ValueError('World Size Error')
+    input = input.contiguous()
+    if weight is not None:
+        weight = weight.contiguous()
+
+    input_shape = shape(input)
+    input_numel = ops.numel(input)
+    size = int(input_numel // input_shape[1])
+    if size == 1 and world_size < 2:
+        raise ValueError(
+            'Expected more than 1 value per channel when training, got input size {}'.format(size))
+
+    # calculate mean/invstd for input.
+    mean, invstd = mint.batch_norm_stats(input, eps)
+    count = mint.full((1,), input_numel //
+                      input_shape[1], dtype=mean.dtype)
+
+    num_channels = input_shape[1]
+    if self_num_features != num_channels:
+        raise ValueError('Features Error')
+    # C, C, 1 -> (2C + 1)
+    combined = mint.cat([mean, invstd, count], dim=0)
+    # Use allgather instead of allreduce because count could be different across
+    # ranks, simple all reduce op can not give correct results.
+    # batch_norm_gather_stats_with_counts calculates global mean & invstd based on
+    # all gathered mean, invstd and count.
+    # world_size * (2C + 1)
+    combined = all_gather_into_tensor(combined, process_group)
+    combined = ops.reshape(combined, [world_size, -1])
+    # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
+    mean_all, invstd_all, count_all = mint.split(
+        combined, num_channels, dim=1)
+    # calculate global mean & invstd
+    mean, invstd = mint.batch_norm_gather_stats_with_counts(
+        input,
+        mean_all,
+        invstd_all,
+        running_mean,
+        running_var,
+        momentum,
+        eps,
+        count_all.view(-1)
+    )
+
+    # apply element-wise normalization
+    out = mint.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
+    return (out, mean, invstd, count_all.view(-1))
+
+
+class SyncBatchNormInner(Cell):
+    def __init__(self, self_num_features, self_world_size):
+        super(SyncBatchNormInner, self).__init__()
+        self.num_features = self_num_features
+        self.world_size = self_world_size
+        self.mode = context.get_context("mode")
+        if self.mode == 1:
+            self.fn_bprop = bprop_pynative
+            self.fn_construct = construct_pynative
+        else:
+            self.fn_bprop = bprop_kbk
+            self.fn_construct = construct_kbk
+
+    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        return self.fn_construct(input, weight, bias, running_mean, running_var, eps, momentum, process_group,
+                                 world_size, self.num_features, self.world_size)
+
+    def bprop(self, input_x, weight, bias, running_mean, running_var, eps, momentum,
+              process_group, world_size, output, doutput):
+        return self.fn_bprop(input_x, weight, bias, running_mean, running_var, eps, momentum,
+                             process_group, world_size, output, doutput)
+
+
+class _SyncBatchNorm(Cell):
+    def __init__(self, num_features, world_size, dtype=mindspore.float32):
+        super(_SyncBatchNorm, self).__init__()
+        self.num_features = num_features
+        self.world_size = world_size
+        self.inner = SyncBatchNormInner(self.num_features, self.world_size)
+
+    def construct(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
+        res = self.inner(input, weight, bias, running_mean,
+                         running_var, eps, momentum, process_group, world_size)
+        output, _, _, _ = res
+        return output
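The comments in construct_pynative/construct_kbk explain why the forward pass all-gathers (mean, invstd, count) rather than all-reducing: per-rank batch sizes may differ, so the reduction must be count-weighted. A standalone NumPy sketch of that reduction — my own illustration of the math behind batch_norm_gather_stats_with_counts, not MindSpore code:

import numpy as np

def gather_stats(means, invstds, counts, eps=1e-5):
    """Combine per-rank (mean, invstd, count) into global mean/invstd."""
    counts = counts.astype(np.float64)[:, None]               # (world_size, 1)
    total = counts.sum()
    mean = (means * counts).sum(axis=0) / total               # count-weighted mean
    var = 1.0 / invstds ** 2 - eps                            # recover per-rank variance
    # law of total variance across ranks
    var = ((var + (means - mean) ** 2) * counts).sum(axis=0) / total
    return mean, 1.0 / np.sqrt(var + eps)

rng = np.random.default_rng(0)
ranks = [rng.normal(size=(64, 3)), rng.normal(size=(32, 3))]  # uneven per-rank batches
means = np.stack([r.mean(axis=0) for r in ranks])
invstds = np.stack([1.0 / np.sqrt(r.var(axis=0) + 1e-5) for r in ranks])
counts = np.array([r.shape[0] for r in ranks])

mean, invstd = gather_stats(means, invstds, counts)
full = np.concatenate(ranks)                                  # single-device reference
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(invstd, 1.0 / np.sqrt(full.var(axis=0) + 1e-5))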
mindspore/mint/nn/layer/activation.py

@@ -77,6 +77,55 @@ class SiLU(Cell):
         return mint.nn.functional.silu(x)
 
 
+class Sigmoid(Cell):
+    r"""
+    Applies sigmoid activation function element-wise.
+
+    Sigmoid function is defined as:
+
+    .. math::
+
+        \text{sigmoid}(x_i) = \frac{1}{1 + \exp(-x_i)},
+
+    where :math:`x_i` is the element of `x`.
+
+    Sigmoid Activation Function Graph:
+
+    .. image:: ../images/Sigmoid.png
+        :align: center
+
+    Inputs:
+        - **input** (Tensor) - `input` is :math:`x` in the preceding formula. Tensor of any dimension,
+          the data type is float16, float32, float64, complex64 or complex128.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input`.
+
+    Raises:
+        TypeError: If dtype of `input` is not float16, float32, float64, complex64 or complex128.
+        TypeError: If `input` is not a Tensor.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> sigmoid = mint.nn.Sigmoid()
+        >>> output = sigmoid(input)
+        >>> print(output)
+        [0.2688  0.11914 0.5     0.881   0.7305]
+    """
+    def __init__(self):
+        """Initialize Sigmoid."""
+        super(Sigmoid, self).__init__()
+
+    def construct(self, input):
+        return mint.nn.functional.sigmoid(input)
+
+
 class LogSigmoid(Cell):
     r"""
     Applies logsigmoid activation element-wise. The input is a Tensor with any valid shape.

@@ -84,7 +133,7 @@ class LogSigmoid(Cell):
     Logsigmoid is defined as:
 
     .. math::
-        \text{
+        \text{LogSigmoid}(x_{i}) = \log(\frac{1}{1 + \exp(-x_i)}),
 
     where :math:`x_{i}` is the element of the input.
 

@@ -127,7 +176,233 @@ class LogSigmoid(Cell):
         return mint.nn.functional.logsigmoid(input)
 
 
+class ELU(Cell):
+    r"""
+    Exponential Linear Unit activation function.
+
+    Applies the exponential linear unit function element-wise. The activation function is defined as:
+
+    .. math::
+        ELU_{i} =
+        \begin{cases}
+        x_i, &\text{if } x_i \geq 0; \cr
+        \alpha * (\exp(x_i) - 1), &\text{otherwise.}
+        \end{cases}
+
+    where :math:`x_i` represents the element of the input and :math:`\alpha` represents the `alpha`
+    parameter, which controls the smoothness of the ELU.
+
+    ELU Activation Function Graph:
+
+    .. image:: ../images/ELU.png
+        :align: center
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        alpha (float, optional): The alpha value of ELU, the data type is float. Default: ``1.0``.
+        inplace (bool, optional): Whether to use inplace mode, the data type is bool. Default: ``False``.
+
+    Inputs:
+        - **input** (Tensor) - The input of ELU is a Tensor of any dimension.
+
+    Outputs:
+        Tensor, with the same shape and type as the `input`.
+
+    Raises:
+        RuntimeError: If the dtype of `input` is not float16, float32 or bfloat16.
+        TypeError: If the dtype of `alpha` is not float.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float32)
+        >>> elu = mint.nn.ELU()
+        >>> result = elu(input)
+        >>> print(result)
+        [-0.63212055 -0.86466473  0.  2.  1.]
+    """
+
+    def __init__(self, alpha=1.0, inplace=False):
+        """Initialize ELU."""
+        super(ELU, self).__init__()
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def construct(self, input):
+        return mint.nn.functional.elu(input, self.alpha, self.inplace)
+
+
+class GLU(Cell):
+    r"""
+    Computes GLU (Gated Linear Unit activation function) of the input tensor.
+
+    .. math::
+        {GLU}(a, b)= a \otimes \sigma(b)
+
+    where :math:`a` is the first half of the `input` Tensor after `input` is split and :math:`b` is the second half.
+
+    Here :math:`\sigma` is the sigmoid function, and :math:`\otimes` is the Hadamard product.
+    See `Language Modeling with Gated Convolutional Networks <https://arxiv.org/abs/1612.08083>`_ .
+
+    Args:
+        dim (int, optional): The dimension to split the input `input`. The value range is `[-r, r)` where `r`
+            is the number of dimensions of `input`. Default: ``-1``, the last dimension in `input`.
+
+    Inputs:
+        - **input** (Tensor) - Tensor to be calculated. Dtype is floating point and the shape
+          is :math:`(\ast_1, N, \ast_2)` where :math:`*` means any number of additional dimensions. :math:`N`
+          is required to be an even number, where :math:`N` is the size of `input` on the dimension
+          selected by `dim`.
+
+    Outputs:
+        Tensor, the same dtype as the `input`, with the shape :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`.
+
+    Raises:
+        TypeError: If `input` is not a Tensor or `dim` is not an int.
+        IndexError: If the value of `dim` is out of the range of `[-r, r)`, where `r` is the number
+            of dimensions of `input`.
+        RuntimeError: If dtype of `input` is not supported.
+        RuntimeError: If the length of `input` in the dimension selected by `dim` is not even.
+
+    Supported Platforms:
+        ``Ascend`` ``CPU``
+
+    Examples:
+        >>> from mindspore import mint, Tensor
+        >>> glu = mint.nn.GLU()
+        >>> input = Tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
+        >>> output = glu(input)
+        >>> print(output)
+        [[0.05744425 0.11973753]
+         [0.33409387 0.41398472]]
+    """
+
+    def __init__(self, dim=-1):
+        """Initialize GLU."""
+        super().__init__("GLU")
+        self.dim = dim
+
+    def construct(self, input):
+        return mint.nn.functional.glu(input, self.dim)
+
+
+class Tanh(Cell):
+    r"""
+    Applies the Tanh function element-wise, returns a new tensor with the hyperbolic tangent of the elements of input.
+
+    Tanh function is defined as:
+
+    .. math::
+        tanh(x_i) = \frac{\exp(x_i) - \exp(-x_i)}{\exp(x_i) + \exp(-x_i)} = \frac{\exp(2x_i) - 1}{\exp(2x_i) + 1},
+
+    where :math:`x_i` is an element of the input Tensor.
+
+    Tanh Activation Function Graph:
+
+    .. image:: ../images/Tanh.png
+        :align: center
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of any dimension, input with data type of float16 or float32.
+
+    Outputs:
+        Tensor, with the same type and shape as the `input`.
+
+    Raises:
+        TypeError: If dtype of `input` is neither float16 nor float32.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> import numpy as np
+        >>> input = Tensor(np.array([1, 2, 3, 2, 1]), mindspore.float16)
+        >>> tanh = mint.nn.Tanh()
+        >>> output = tanh(input)
+        >>> print(output)
+        [0.7617 0.964  0.995  0.964  0.7617]
+    """
+
+    def __init__(self):
+        """Initialize Tanh."""
+        super(Tanh, self).__init__()
+
+    def construct(self, input):
+        return mint.nn.functional.tanh(input)
+
+
+class Threshold(Cell):
+    r"""
+    Compute the Threshold activation function element-wise.
+
+    The Threshold is defined as:
+
+    .. math::
+        y =
+        \begin{cases}
+        x, &\text{ if } x > \text{threshold} \\
+        \text{value}, &\text{ otherwise }
+        \end{cases}
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        threshold (Union[int, float]): The value of the threshold.
+        value (Union[int, float]): The value to replace with when the element is less than or equal
+            to the threshold.
+        inplace (bool, optional): Whether to apply the replacement in place. Default: ``False``.
+
+    Inputs:
+        - **input** (Tensor) - The input Tensor.
+
+    Outputs:
+        Tensor, the same shape and data type as the input.
+
+    Raises:
+        TypeError: If `input` is not a Tensor.
+        TypeError: If `threshold` is not a float or an int.
+        TypeError: If `value` is not a float or an int.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, mint
+        >>> inputs = mindspore.Tensor([0.0, 2, 3], mindspore.float32)
+        >>> net = mint.nn.Threshold(1, 100)
+        >>> outputs = net(inputs)
+        >>> print(outputs)
+        [100.   2.   3.]
+    """
+
+    def __init__(self, threshold, value, inplace=False):
+        """Initialize Threshold."""
+        super(Threshold, self).__init__()
+        self.threshold = threshold
+        self.value = value
+        self.inplace = inplace
+
+    def construct(self, input):
+        return mint.nn.functional.threshold(input, self.threshold, self.value,
+                                            self.inplace)
+
+
 __all__ = [
     'LogSigmoid',
     'SiLU',
+    'ELU',
+    'GLU',
+    'Tanh',
+    'Threshold',
 ]
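Each of these new Cells is a thin wrapper whose construct defers to the matching mint.nn.functional op, so the class and functional forms agree. A quick equivalence sketch, assuming a backend where the ops are available (the docstrings above list Ascend for ELU/Tanh/Threshold):

import numpy as np
import mindspore
from mindspore import Tensor, mint

x = Tensor(np.array([-1.0, 0.0, 2.0]), mindspore.float32)
# each Cell's construct simply calls the functional form, so the pairs agree
assert np.allclose(mint.nn.Tanh()(x).asnumpy(),
                   mint.nn.functional.tanh(x).asnumpy())
assert np.allclose(mint.nn.ELU()(x).asnumpy(),
                   mint.nn.functional.elu(x).asnumpy())
assert np.allclose(mint.nn.Threshold(1, 100)(x).asnumpy(),
                   mint.nn.functional.threshold(x, 1, 100).asnumpy())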