mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.6.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
- mindspore/Newtonsoft.Json.dll +0 -0
- mindspore/__init__.py +13 -6
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_check_jit_forbidden_api.py +3 -0
- mindspore/_checkparam.py +3 -38
- mindspore/_deprecated/__init__.py +17 -0
- mindspore/_deprecated/jit.py +198 -0
- mindspore/_extends/builtin_operations.py +1 -1
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +6 -7
- mindspore/_extends/parse/compile_config.py +83 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +394 -0
- mindspore/_extends/parse/jit_fallback_modules/__init__.py +0 -0
- mindspore/_extends/parse/jit_fallback_modules/check_utils.py +123 -0
- mindspore/_extends/parse/jit_fallback_modules/third_party_modules.py +50 -0
- mindspore/_extends/parse/parser.py +47 -198
- mindspore/_extends/parse/resources.py +1 -5
- mindspore/_extends/parse/standard_method.py +229 -99
- mindspore/_extends/pijit/__init__.py +2 -2
- mindspore/_extends/pijit/pijit_func_white_list.py +17 -12
- mindspore/_extends/pijit/tensor_func_list.py +27 -0
- mindspore/_extends/utils.py +1 -1
- mindspore/amp.py +11 -5
- mindspore/atlprov.dll +0 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/__init__.py +2 -2
- mindspore/boost/base.py +3 -7
- mindspore/boost/boost_cell_wrapper.py +138 -43
- mindspore/c1.dll +0 -0
- mindspore/c1xx.dll +0 -0
- mindspore/c2.dll +0 -0
- mindspore/common/__init__.py +6 -3
- mindspore/common/_grad_function.py +56 -0
- mindspore/common/_pijit_context.py +14 -5
- mindspore/common/_register_for_tensor.py +1 -2
- mindspore/common/_stub_tensor.py +30 -14
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +4760 -0
- mindspore/common/api.py +480 -372
- mindspore/common/auto_dynamic_shape.py +41 -44
- mindspore/common/dtype.py +39 -36
- mindspore/common/dump.py +9 -6
- mindspore/common/file_system.py +9 -1
- mindspore/common/generator.py +5 -0
- mindspore/common/hook_handle.py +6 -2
- mindspore/common/initializer.py +13 -10
- mindspore/common/jit_begin_end.py +94 -0
- mindspore/common/jit_config.py +6 -1
- mindspore/common/jit_context.py +76 -0
- mindspore/common/jit_trace.py +378 -0
- mindspore/common/lazy_inline.py +9 -3
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/mutable.py +5 -4
- mindspore/common/parameter.py +135 -52
- mindspore/common/seed.py +2 -2
- mindspore/common/sparse_tensor.py +23 -17
- mindspore/common/tensor.py +975 -1981
- mindspore/communication/__init__.py +7 -5
- mindspore/communication/_comm_helper.py +52 -2
- mindspore/communication/comm_func.py +240 -181
- mindspore/communication/management.py +95 -26
- mindspore/context.py +324 -573
- mindspore/dataset/__init__.py +65 -37
- mindspore/dataset/audio/__init__.py +2 -8
- mindspore/dataset/audio/transforms.py +3 -17
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +87 -6
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +6 -5
- mindspore/dataset/engine/datasets.py +292 -267
- mindspore/dataset/engine/datasets_audio.py +22 -8
- mindspore/dataset/engine/datasets_standard_format.py +46 -27
- mindspore/dataset/engine/datasets_text.py +78 -48
- mindspore/dataset/engine/datasets_user_defined.py +183 -117
- mindspore/dataset/engine/datasets_vision.py +120 -44
- mindspore/dataset/engine/iterators.py +283 -63
- mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +1 -1
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +289 -43
- mindspore/dataset/engine/serializer_deserializer.py +3 -2
- mindspore/dataset/engine/validators.py +53 -11
- mindspore/dataset/text/__init__.py +7 -6
- mindspore/dataset/text/transforms.py +6 -5
- mindspore/dataset/text/utils.py +3 -3
- mindspore/dataset/transforms/__init__.py +0 -9
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +31 -14
- mindspore/dataset/utils/browse_dataset.py +1 -1
- mindspore/dataset/vision/__init__.py +2 -9
- mindspore/dataset/vision/transforms.py +202 -158
- mindspore/dataset/vision/utils.py +7 -5
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +153 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +123 -0
- mindspore/{ops_generate/gen_constants.py → device_context/cpu/__init__.py} +6 -17
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +170 -0
- mindspore/dnnl.dll +0 -0
- mindspore/dpcmi.dll +0 -0
- mindspore/experimental/es/embedding_service.py +35 -27
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +209 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/map_parameter.py +4 -4
- mindspore/experimental/optim/adadelta.py +6 -6
- mindspore/experimental/optim/adagrad.py +4 -4
- mindspore/experimental/optim/adam.py +7 -0
- mindspore/experimental/optim/adamax.py +4 -4
- mindspore/experimental/optim/adamw.py +4 -0
- mindspore/experimental/optim/asgd.py +1 -1
- mindspore/experimental/optim/lr_scheduler.py +73 -46
- mindspore/experimental/optim/radam.py +34 -31
- mindspore/experimental/optim/rprop.py +1 -1
- mindspore/experimental/optim/sgd.py +1 -1
- mindspore/hal/contiguous_tensors_handle.py +6 -10
- mindspore/hal/device.py +55 -53
- mindspore/hal/event.py +52 -52
- mindspore/hal/memory.py +179 -120
- mindspore/hal/stream.py +150 -109
- mindspore/include/api/context.h +0 -1
- mindspore/include/dataset/constants.h +7 -4
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +50 -0
- mindspore/mindrecord/__init__.py +21 -8
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_backend.dll → mindspore_ops_host.dll} +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/__init__.py +798 -761
- mindspore/mint/distributed/__init__.py +70 -4
- mindspore/mint/distributed/distributed.py +2679 -44
- mindspore/mint/linalg/__init__.py +8 -0
- mindspore/mint/nn/__init__.py +743 -22
- mindspore/mint/nn/functional.py +716 -23
- mindspore/mint/nn/layer/__init__.py +21 -4
- mindspore/mint/nn/layer/_functions.py +334 -0
- mindspore/mint/nn/layer/activation.py +276 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +933 -0
- mindspore/mint/nn/layer/normalization.py +223 -28
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +235 -0
- mindspore/mint/optim/__init__.py +3 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/optim/sgd.py +171 -0
- mindspore/mint/special/__init__.py +2 -1
- mindspore/msobj140.dll +0 -0
- mindspore/mspdb140.dll +0 -0
- mindspore/mspdbcore.dll +0 -0
- mindspore/mspdbst.dll +0 -0
- mindspore/mspft140.dll +0 -0
- mindspore/msvcdis140.dll +0 -0
- mindspore/msvcp140_1.dll +0 -0
- mindspore/msvcp140_2.dll +0 -0
- mindspore/msvcp140_atomic_wait.dll +0 -0
- mindspore/msvcp140_codecvt_ids.dll +0 -0
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/__init__.py +4 -1
- mindspore/nn/cell.py +1373 -192
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +29 -27
- mindspore/nn/layer/basic.py +51 -35
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/container.py +1 -1
- mindspore/nn/layer/conv.py +53 -42
- mindspore/nn/layer/embedding.py +12 -11
- mindspore/nn/layer/normalization.py +56 -49
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +120 -42
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +59 -36
- mindspore/nn/learning_rate_schedule.py +8 -4
- mindspore/nn/loss/loss.py +58 -55
- mindspore/nn/optim/ada_grad.py +7 -5
- mindspore/nn/optim/adadelta.py +11 -9
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +19 -15
- mindspore/nn/optim/adamax.py +8 -7
- mindspore/nn/optim/adasum.py +5 -5
- mindspore/nn/optim/asgd.py +3 -1
- mindspore/nn/optim/ftrl.py +11 -9
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/lazyadam.py +12 -10
- mindspore/nn/optim/momentum.py +7 -6
- mindspore/nn/optim/optimizer.py +3 -3
- mindspore/nn/optim/proximal_ada_grad.py +12 -10
- mindspore/nn/optim/rmsprop.py +13 -12
- mindspore/nn/optim/rprop.py +11 -9
- mindspore/nn/optim/sgd.py +9 -6
- mindspore/nn/optim/tft_wrapper.py +5 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/probability/bijector/bijector.py +17 -11
- mindspore/nn/probability/bijector/gumbel_cdf.py +5 -5
- mindspore/nn/probability/bijector/invert.py +2 -2
- mindspore/nn/probability/bijector/scalar_affine.py +3 -3
- mindspore/nn/probability/bijector/softplus.py +3 -2
- mindspore/nn/probability/distribution/beta.py +3 -3
- mindspore/nn/probability/distribution/categorical.py +1 -1
- mindspore/nn/probability/distribution/cauchy.py +4 -2
- mindspore/nn/probability/distribution/exponential.py +6 -7
- mindspore/nn/probability/distribution/gamma.py +2 -2
- mindspore/nn/probability/distribution/gumbel.py +2 -2
- mindspore/nn/probability/distribution/half_normal.py +5 -3
- mindspore/nn/probability/distribution/logistic.py +5 -3
- mindspore/nn/probability/distribution/poisson.py +1 -1
- mindspore/nn/probability/distribution/uniform.py +5 -3
- mindspore/nn/reinforcement/_tensors_queue.py +1 -1
- mindspore/nn/reinforcement/tensor_array.py +1 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/__init__.py +6 -6
- mindspore/nn/wrap/cell_wrapper.py +181 -122
- mindspore/nn/wrap/grad_reducer.py +45 -36
- mindspore/nn/wrap/loss_scale.py +6 -7
- mindspore/numpy/array_creations.py +63 -65
- mindspore/numpy/array_ops.py +149 -144
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +17 -18
- mindspore/numpy/utils_const.py +5 -6
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +5 -3
- mindspore/ops/_grad_experimental/grad_comm_ops.py +112 -16
- mindspore/ops/_grad_experimental/grad_debug_ops.py +14 -2
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_grad_experimental/taylor_rule.py +29 -0
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_register_for_op.py +0 -11
- mindspore/{ops_generate → ops/_utils}/arg_dtype_cast.py +123 -4
- mindspore/{ops_generate → ops/_utils}/arg_handler.py +3 -65
- mindspore/ops/_vmap/vmap_array_ops.py +52 -25
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +21 -14
- mindspore/ops/_vmap/vmap_math_ops.py +15 -16
- mindspore/ops/_vmap/vmap_nn_ops.py +29 -42
- mindspore/ops/auto_generate/__init__.py +4 -3
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +258 -46
- mindspore/ops/auto_generate/gen_extend_func.py +757 -185
- mindspore/ops/auto_generate/gen_ops_def.py +4197 -2243
- mindspore/ops/auto_generate/gen_ops_prim.py +16976 -6055
- mindspore/ops/auto_generate/pyboost_inner_prim.py +221 -87
- mindspore/ops/composite/__init__.py +2 -1
- mindspore/ops/composite/base.py +20 -25
- mindspore/ops/composite/math_ops.py +6 -16
- mindspore/ops/composite/multitype_ops/__init__.py +5 -2
- mindspore/ops/composite/multitype_ops/_compile_utils.py +228 -30
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -2
- mindspore/ops/composite/multitype_ops/add_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/div_impl.py +6 -4
- mindspore/ops/composite/multitype_ops/equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/floordiv_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/getitem_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/greater_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/greater_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/invert_impl.py +50 -0
- mindspore/ops/composite/multitype_ops/left_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/less_equal_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/less_impl.py +4 -3
- mindspore/ops/composite/multitype_ops/logic_not_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/logical_and_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/logical_or_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mod_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/mul_impl.py +3 -2
- mindspore/ops/composite/multitype_ops/negative_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_equal_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/ones_like_impl.py +18 -0
- mindspore/ops/composite/multitype_ops/pow_impl.py +2 -30
- mindspore/ops/composite/multitype_ops/right_shift_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/setitem_impl.py +2 -1
- mindspore/ops/composite/multitype_ops/sub_impl.py +2 -1
- mindspore/ops/function/__init__.py +40 -2
- mindspore/ops/function/_add_attr_func.py +58 -0
- mindspore/ops/function/array_func.py +2089 -2403
- mindspore/ops/function/clip_func.py +80 -23
- mindspore/ops/function/debug_func.py +57 -57
- mindspore/ops/function/grad/__init__.py +1 -0
- mindspore/ops/function/grad/grad_func.py +104 -71
- mindspore/ops/function/image_func.py +2 -2
- mindspore/ops/function/linalg_func.py +47 -78
- mindspore/ops/function/math_func.py +4351 -3813
- mindspore/ops/function/nn_func.py +1712 -637
- mindspore/ops/function/other_func.py +159 -1
- mindspore/ops/function/parameter_func.py +18 -84
- mindspore/ops/function/random_func.py +452 -387
- mindspore/ops/function/reshard_func.py +4 -70
- mindspore/ops/function/sparse_func.py +3 -3
- mindspore/ops/function/sparse_unary_func.py +6 -6
- mindspore/ops/function/spectral_func.py +25 -58
- mindspore/ops/function/vmap_func.py +26 -18
- mindspore/ops/functional.py +23 -7
- mindspore/ops/functional_overload.py +1548 -0
- mindspore/ops/op_info_register.py +32 -244
- mindspore/ops/operations/__init__.py +23 -15
- mindspore/ops/operations/_custom_ops_utils.py +235 -0
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -43
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +43 -84
- mindspore/ops/operations/_ms_kernel.py +4 -10
- mindspore/ops/operations/_rl_inner_ops.py +1 -1
- mindspore/ops/operations/_scalar_ops.py +3 -2
- mindspore/ops/operations/_sequence_ops.py +1 -1
- mindspore/ops/operations/_tensor_array.py +1 -1
- mindspore/ops/operations/array_ops.py +81 -324
- mindspore/ops/operations/comm_ops.py +154 -108
- mindspore/ops/operations/custom_ops.py +298 -87
- mindspore/ops/operations/debug_ops.py +157 -59
- mindspore/ops/operations/inner_ops.py +7 -5
- mindspore/ops/operations/linalg_ops.py +1 -57
- mindspore/ops/operations/manually_defined/_inner.py +1 -1
- mindspore/ops/operations/manually_defined/ops_def.py +928 -180
- mindspore/ops/operations/math_ops.py +32 -234
- mindspore/ops/operations/nn_ops.py +212 -531
- mindspore/ops/operations/other_ops.py +62 -9
- mindspore/ops/operations/random_ops.py +13 -7
- mindspore/ops/operations/reshard_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +2 -2
- mindspore/ops/primitive.py +66 -53
- mindspore/ops/tensor_method.py +1895 -0
- mindspore/ops_generate/__init__.py +0 -5
- mindspore/ops_generate/aclnn/__init__.py +0 -0
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +135 -0
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +257 -0
- mindspore/ops_generate/api/__init__.py +0 -0
- mindspore/ops_generate/api/add_tensor_docs_generator.py +56 -0
- mindspore/ops_generate/api/cpp_create_prim_instance_helper_generator.py +105 -0
- mindspore/ops_generate/api/functional_map_cpp_generator.py +504 -0
- mindspore/ops_generate/api/functional_overload_py_generator.py +112 -0
- mindspore/ops_generate/api/functions_cc_generator.py +237 -0
- mindspore/ops_generate/api/gen_api.py +103 -0
- mindspore/ops_generate/api/op_api_proto.py +235 -0
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +461 -0
- mindspore/ops_generate/common/__init__.py +0 -0
- mindspore/ops_generate/common/base_generator.py +11 -0
- mindspore/ops_generate/common/gen_constants.py +91 -0
- mindspore/ops_generate/common/gen_utils.py +348 -0
- mindspore/ops_generate/common/op_proto.py +473 -0
- mindspore/ops_generate/common/template.py +523 -0
- mindspore/ops_generate/gen_ops.py +22 -1069
- mindspore/ops_generate/op_def/__init__.py +0 -0
- mindspore/ops_generate/op_def/gen_op_def.py +90 -0
- mindspore/ops_generate/op_def/lite_ops_cpp_generator.py +191 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +296 -0
- mindspore/ops_generate/op_def/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/op_def/ops_name_h_generator.py +83 -0
- mindspore/ops_generate/op_def/ops_primitive_h_generator.py +125 -0
- mindspore/ops_generate/op_def_py/__init__.py +0 -0
- mindspore/ops_generate/op_def_py/gen_op_def_py.py +47 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +132 -0
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +489 -0
- mindspore/ops_generate/pyboost/__init__.py +0 -0
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +139 -0
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +175 -0
- mindspore/ops_generate/pyboost/op_template_parser.py +517 -0
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +407 -0
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +100 -0
- mindspore/ops_generate/pyboost/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +155 -0
- mindspore/ops_generate/pyboost/pyboost_inner_prim_generator.py +132 -0
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +272 -0
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +938 -0
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +357 -0
- mindspore/ops_generate/{pyboost_utils.py → pyboost/pyboost_utils.py} +179 -36
- mindspore/ops_generate/resources/__init__.py +0 -0
- mindspore/ops_generate/resources/resource_list.py +30 -0
- mindspore/ops_generate/resources/resource_loader.py +36 -0
- mindspore/ops_generate/resources/resource_manager.py +64 -0
- mindspore/ops_generate/resources/yaml_loader.py +88 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +122 -0
- mindspore/parallel/__init__.py +7 -3
- mindspore/parallel/_auto_parallel_context.py +159 -40
- mindspore/parallel/_cell_wrapper.py +132 -15
- mindspore/parallel/_parallel_serialization.py +107 -5
- mindspore/parallel/_ps_context.py +1 -1
- mindspore/parallel/_recovery_context.py +7 -2
- mindspore/parallel/_tensor.py +142 -18
- mindspore/parallel/_utils.py +199 -23
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/auto_parallel.py +732 -0
- mindspore/parallel/checkpoint_convert.py +159 -0
- mindspore/parallel/checkpoint_transform.py +700 -35
- mindspore/parallel/cluster/process_entity/_api.py +276 -50
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +21 -4
- mindspore/parallel/function/__init__.py +24 -0
- mindspore/parallel/function/reshard_func.py +258 -0
- mindspore/parallel/nn/__init__.py +25 -0
- mindspore/parallel/nn/parallel_cell_wrapper.py +263 -0
- mindspore/parallel/nn/parallel_grad_reducer.py +169 -0
- mindspore/parallel/parameter_broadcast.py +25 -14
- mindspore/parallel/shard.py +137 -59
- mindspore/parallel/transform_safetensors.py +364 -305
- mindspore/pgodb140.dll +0 -0
- mindspore/pgort140.dll +0 -0
- mindspore/profiler/__init__.py +22 -5
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +170 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +264 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +109 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +415 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +372 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +250 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +320 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +327 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +376 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +96 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +139 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +186 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +221 -0
- mindspore/profiler/common/path_manager.py +395 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +500 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_meta_data.py +74 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +251 -0
- mindspore/profiler/common/profiler_path_manager.py +179 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +341 -75
- mindspore/profiler/envprofiler.py +163 -0
- mindspore/profiler/experimental_config.py +197 -0
- mindspore/profiler/mstx.py +242 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +335 -0
- mindspore/profiler/profiler.py +1073 -90
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +118 -0
- mindspore/profiler/schedule.py +243 -0
- mindspore/rewrite/api/node.py +15 -13
- mindspore/rewrite/api/symbol_tree.py +2 -3
- mindspore/run_check/_check_version.py +27 -20
- mindspore/run_check/run_check.py +1 -1
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +177 -0
- mindspore/runtime/memory.py +416 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/safeguard/rewrite_obfuscation.py +12 -9
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tbbmalloc.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +8 -8
- mindspore/train/_utils.py +96 -27
- mindspore/train/amp.py +9 -5
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +2 -16
- mindspore/train/callback/_checkpoint.py +53 -55
- mindspore/train/callback/_cluster_monitor.py +14 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +103 -68
- mindspore/train/callback/_history.py +8 -5
- mindspore/train/callback/_lambda_callback.py +2 -2
- mindspore/train/callback/_landscape.py +0 -3
- mindspore/train/callback/_loss_monitor.py +2 -1
- mindspore/train/callback/_on_request_exit.py +6 -5
- mindspore/train/callback/_reduce_lr_on_plateau.py +11 -6
- mindspore/train/callback/_summary_collector.py +52 -19
- mindspore/train/callback/_time_monitor.py +2 -1
- mindspore/train/callback/{_tft_register.py → _train_fault_tolerance.py} +228 -108
- mindspore/train/data_sink.py +25 -2
- mindspore/train/dataset_helper.py +15 -16
- mindspore/train/loss_scale_manager.py +8 -7
- mindspore/train/metrics/accuracy.py +3 -3
- mindspore/train/metrics/confusion_matrix.py +9 -9
- mindspore/train/metrics/error.py +3 -3
- mindspore/train/metrics/hausdorff_distance.py +4 -4
- mindspore/train/metrics/mean_surface_distance.py +3 -3
- mindspore/train/metrics/metric.py +0 -12
- mindspore/train/metrics/occlusion_sensitivity.py +4 -2
- mindspore/train/metrics/precision.py +11 -10
- mindspore/train/metrics/recall.py +9 -9
- mindspore/train/metrics/root_mean_square_surface_distance.py +2 -2
- mindspore/train/mind_ir_pb2.py +174 -46
- mindspore/train/model.py +269 -136
- mindspore/train/serialization.py +622 -978
- mindspore/train/summary/_summary_adapter.py +2 -2
- mindspore/train/summary/summary_record.py +2 -3
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dryrun.py +140 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/runtime_execution_order_check.py +552 -0
- mindspore/utils/utils.py +138 -4
- mindspore/vcmeta.dll +0 -0
- mindspore/vcruntime140.dll +0 -0
- mindspore/vcruntime140_1.dll +0 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/METADATA +3 -3
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/RECORD +587 -418
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/entry_points.txt +1 -1
- mindspore/_install_custom.py +0 -43
- mindspore/common/_register_for_adapter.py +0 -74
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +0 -252
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -197
- mindspore/ops/operations/_opaque_predicate_registry.py +0 -41
- mindspore/ops_generate/gen_aclnn_implement.py +0 -263
- mindspore/ops_generate/gen_ops_inner_prim.py +0 -131
- mindspore/ops_generate/gen_pyboost_func.py +0 -1052
- mindspore/ops_generate/gen_utils.py +0 -209
- mindspore/ops_generate/op_proto.py +0 -145
- mindspore/ops_generate/template.py +0 -261
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.6.0.dist-info}/top_level.txt +0 -0
|
@@ -28,7 +28,11 @@ from mindspore import _checkparam as validator
|
|
|
28
28
|
from mindspore.common import dtype as mstype
|
|
29
29
|
from mindspore.nn.cell import Cell
|
|
30
30
|
from mindspore.nn.layer.normalization import LayerNormExt as LayerNorm
|
|
31
|
-
from mindspore.
|
|
31
|
+
from mindspore.communication import get_group_size
|
|
32
|
+
from mindspore.communication._comm_helper import GlobalComm
|
|
33
|
+
from mindspore.ops.function import batch_norm
|
|
34
|
+
|
|
35
|
+
from ._functions import _SyncBatchNorm
|
|
32
36
|
|
|
33
37
|
|
|
34
38
|
class _NormBase(Cell):
|
|
@@ -43,6 +47,7 @@ class _NormBase(Cell):
|
|
|
43
47
|
dtype=None
|
|
44
48
|
) -> None:
|
|
45
49
|
super(_NormBase, self).__init__()
|
|
50
|
+
self.set_train()
|
|
46
51
|
self.shape = ops.Shape()
|
|
47
52
|
self.num_features = num_features
|
|
48
53
|
self.eps = eps
|
|
@@ -55,8 +60,6 @@ class _NormBase(Cell):
|
|
|
55
60
|
Tensor(np.empty(num_features), dtype=self.dtype), name="weight")
|
|
56
61
|
self.bias = Parameter(
|
|
57
62
|
Tensor(np.empty(num_features), dtype=self.dtype), name="bias")
|
|
58
|
-
self.weight: Optional[Parameter]
|
|
59
|
-
self.bias: Optional[Parameter]
|
|
60
63
|
else:
|
|
61
64
|
self.weight = None
|
|
62
65
|
self.bias = None
|
|
@@ -65,11 +68,8 @@ class _NormBase(Cell):
|
|
|
65
68
|
requires_grad=False, name="running_mean")
|
|
66
69
|
self.running_var = Parameter(Tensor(np.ones(num_features), dtype=self.dtype),
|
|
67
70
|
requires_grad=False, name="running_var")
|
|
68
|
-
self.
|
|
69
|
-
self.running_var: Optional[Tensor]
|
|
70
|
-
self.num_batches_tracked = Parameter(Tensor(0, dtype=ms.float32),
|
|
71
|
+
self.num_batches_tracked = Parameter(Tensor(0, dtype=ms.int64),
|
|
71
72
|
requires_grad=False, name="num_batches_tracked")
|
|
72
|
-
self.num_batches_tracked: Optional[Tensor]
|
|
73
73
|
else:
|
|
74
74
|
self.running_mean = None
|
|
75
75
|
self.running_var = None
|
|
@@ -84,7 +84,7 @@ class _NormBase(Cell):
|
|
|
84
84
|
np.zeros(self.num_features), dtype=self.dtype)
|
|
85
85
|
one_running_var = Tensor(
|
|
86
86
|
np.ones(self.num_features), dtype=self.dtype)
|
|
87
|
-
zero_num_batches_tracked = Tensor(0, dtype=ms.
|
|
87
|
+
zero_num_batches_tracked = Tensor(0, dtype=ms.int64)
|
|
88
88
|
|
|
89
89
|
ops.assign(self.running_mean, zero_running_mean)
|
|
90
90
|
ops.assign(self.running_var, one_running_var)
|
|
@@ -122,6 +122,7 @@ class _BatchNorm(_NormBase):
|
|
|
122
122
|
dtype)
|
|
123
123
|
self.training = True
|
|
124
124
|
|
|
125
|
+
|
|
125
126
|
def _check_input_dim(self, input):
|
|
126
127
|
raise NotImplementedError
|
|
127
128
|
|
|
@@ -135,11 +136,9 @@ class _BatchNorm(_NormBase):
|
|
|
135
136
|
|
|
136
137
|
if self.training and self.track_running_stats:
|
|
137
138
|
if self.num_batches_tracked is not None:
|
|
138
|
-
|
|
139
|
-
ops.assign_add(self.num_batches_tracked,
|
|
140
|
-
num_batches_tracked_one)
|
|
139
|
+
self.num_batches_tracked += 1
|
|
141
140
|
if self.momentum is None:
|
|
142
|
-
exponential_average_factor =
|
|
141
|
+
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
|
|
143
142
|
else:
|
|
144
143
|
exponential_average_factor = self.momentum
|
|
145
144
|
|
|
@@ -206,7 +205,7 @@ class BatchNorm1d(_BatchNorm):
|
|
|
206
205
|
Tensor, has the same type and shape as `input`.
|
|
207
206
|
|
|
208
207
|
Raises:
|
|
209
|
-
TypeError: If `num_features` is not
|
|
208
|
+
TypeError: If `num_features` is not an int number.
|
|
210
209
|
TypeError: If `eps` is not a float.
|
|
211
210
|
ValueError: If `num_features` is less than 1.
|
|
212
211
|
|
|
@@ -241,7 +240,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
241
240
|
|
|
242
241
|
.. math::
|
|
243
242
|
|
|
244
|
-
y = \frac{x -
|
|
243
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
245
244
|
|
|
246
245
|
The mean and standard-deviation are calculated per-dimension over
|
|
247
246
|
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
|
|
@@ -249,8 +248,8 @@ class BatchNorm2d(_BatchNorm):
|
|
|
249
248
|
elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
|
|
250
249
|
|
|
251
250
|
.. warning::
|
|
252
|
-
This API does not support Dynamic Rank.
|
|
253
|
-
This is an experimental API that is subject to change or deletion.
|
|
251
|
+
- This API does not support Dynamic Rank.
|
|
252
|
+
- This is an experimental API that is subject to change or deletion.
|
|
254
253
|
|
|
255
254
|
Args:
|
|
256
255
|
num_features (int): `C` from an expected input of shape :math:`(N, C, H, W)`.
|
|
@@ -263,7 +262,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
263
262
|
track_running_stats (bool, optional): a boolean value that when set to ``True``, this
|
|
264
263
|
cell tracks the running mean and variance, and when set to ``False``,
|
|
265
264
|
this cell does not track such statistics. And this cell always uses batch statistics
|
|
266
|
-
in both
|
|
265
|
+
in both train and eval modes. Default: ``True`` .
|
|
267
266
|
dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
|
|
268
267
|
|
|
269
268
|
Inputs:
|
|
@@ -273,7 +272,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
273
272
|
Tensor, has the same type and shape as `input`.
|
|
274
273
|
|
|
275
274
|
Raises:
|
|
276
|
-
TypeError: If `num_features` is not
|
|
275
|
+
TypeError: If `num_features` is not an int number.
|
|
277
276
|
TypeError: If `eps` is not a float.
|
|
278
277
|
ValueError: If `num_features` is less than 1.
|
|
279
278
|
|
|
@@ -311,7 +310,7 @@ class BatchNorm3d(_BatchNorm):
|
|
|
311
310
|
|
|
312
311
|
.. math::
|
|
313
312
|
|
|
314
|
-
y = \frac{x -
|
|
313
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
315
314
|
|
|
316
315
|
The mean and standard-deviation are calculated per-dimension over
|
|
317
316
|
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
|
|
@@ -343,7 +342,7 @@ class BatchNorm3d(_BatchNorm):
|
|
|
343
342
|
Tensor, has the same type and shape as `input`.
|
|
344
343
|
|
|
345
344
|
Raises:
|
|
346
|
-
TypeError: If `num_features` is not
|
|
345
|
+
TypeError: If `num_features` is not an int number.
|
|
347
346
|
TypeError: If `eps` is not a float.
|
|
348
347
|
ValueError: If `num_features` is less than 1.
|
|
349
348
|
|
|
@@ -402,7 +401,7 @@ class GroupNorm(Cell):
|
|
|
402
401
|
additional dimensions.
|
|
403
402
|
|
|
404
403
|
Outputs:
|
|
405
|
-
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `
|
|
404
|
+
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input`.
|
|
406
405
|
|
|
407
406
|
Raises:
|
|
408
407
|
TypeError: If `num_groups` or `num_channels` is not an int.
|
|
@@ -435,8 +434,8 @@ class GroupNorm(Cell):
|
|
|
435
434
|
"""Initialize GroupNorm."""
|
|
436
435
|
super(GroupNorm, self).__init__()
|
|
437
436
|
ms_dtype = mstype.float32 if dtype is None else dtype
|
|
438
|
-
|
|
439
|
-
|
|
437
|
+
weight_init = 'ones'
|
|
438
|
+
bias_init = 'zeros'
|
|
440
439
|
|
|
441
440
|
self.num_groups = validator.check_positive_int(
|
|
442
441
|
num_groups, "num_groups", self.cls_name)
|
|
@@ -450,14 +449,14 @@ class GroupNorm(Cell):
|
|
|
450
449
|
self.affine = validator.check_bool(
|
|
451
450
|
affine, arg_name="affine", prim_name=self.cls_name)
|
|
452
451
|
|
|
453
|
-
self.
|
|
454
|
-
|
|
455
|
-
self.
|
|
456
|
-
|
|
452
|
+
self.weight = Parameter(initializer(
|
|
453
|
+
weight_init, self.num_channels, dtype=ms_dtype), name="weight", requires_grad=affine)
|
|
454
|
+
self.bias = Parameter(initializer(
|
|
455
|
+
bias_init, self.num_channels, dtype=ms_dtype), name="bias", requires_grad=affine)
|
|
457
456
|
|
|
458
457
|
def _cal_output(self, x):
|
|
459
458
|
"""calculate groupnorm output"""
|
|
460
|
-
return group_norm(x, self.num_groups, self.
|
|
459
|
+
return ops.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
|
|
461
460
|
|
|
462
461
|
def extend_repr(self):
|
|
463
462
|
return 'num_groups={}, num_channels={}, eps={}, affine={}'.format(
|
|
@@ -468,10 +467,206 @@ class GroupNorm(Cell):
|
|
|
468
467
|
return output
|
|
469
468
|
|
|
470
469
|
|
|
470
|
+
class SyncBatchNorm(_BatchNorm):
|
|
471
|
+
r"""
|
|
472
|
+
Sync Batch Normalization layer over a N-dimension input.
|
|
473
|
+
|
|
474
|
+
Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch
|
|
475
|
+
Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input
|
|
476
|
+
within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
|
|
477
|
+
Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
|
478
|
+
feature using a mini-batch of data and the learned parameters which can be described in the following formula.
|
|
479
|
+
|
|
480
|
+
.. math::
|
|
481
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
482
|
+
|
|
483
|
+
.. warning::
|
|
484
|
+
This is an experimental API that is subject to change or deletion.
|
|
485
|
+
|
|
486
|
+
Args:
|
|
487
|
+
num_features (int): `C` from an expected input of size :math:`(N, C, +)`.
|
|
488
|
+
eps (float, optional): :math:`\epsilon`, a value added to the denominator for numerical stability.
|
|
489
|
+
Default: ``1e-5`` .
|
|
490
|
+
momentum (float, optional): A floating hyperparameter of the momentum for the
|
|
491
|
+
running_mean and running_var computation. Default: ``0.1`` .
|
|
492
|
+
affine (bool, optional): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` are learnable
|
|
493
|
+
parameters. When set to ``False`` , :math:`\gamma` and :math:`\beta` are unlearnable parameters.
|
|
494
|
+
Default: ``True`` .
|
|
495
|
+
track_running_stats (bool, optional): a boolean value that when set to ``True``, this
|
|
496
|
+
cell tracks the running mean and variance, and when set to ``False``,
|
|
497
|
+
this cell does not track such statistics. And this cell always uses batch statistics
|
|
498
|
+
in both training and eval modes. Default: ``True`` .
|
|
499
|
+
process_group (str, optional): synchronization of stats happen within each process group individually.
|
|
500
|
+
Default behavior is synchronization across the whole world. Default: ``None`` .
|
|
501
|
+
dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
|
|
502
|
+
|
|
503
|
+
Inputs:
|
|
504
|
+
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, +)`.
|
|
505
|
+
|
|
506
|
+
Outputs:
|
|
507
|
+
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, +)`.
|
|
508
|
+
|
|
509
|
+
Raises:
|
|
510
|
+
TypeError: If `num_features` is not an int.
|
|
511
|
+
TypeError: If `eps` is not a float.
|
|
512
|
+
ValueError: If `num_features` is less than 1.
|
|
513
|
+
ValueError: If `momentum` is not in range [0, 1].
|
|
514
|
+
ValueError: If rank_id in `process_group` is not in range [0, rank_size).
|
|
515
|
+
|
|
516
|
+
Supported Platforms:
|
|
517
|
+
``Ascend``
|
|
518
|
+
|
|
519
|
+
Examples:
|
|
520
|
+
.. note::
|
|
521
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
522
|
+
|
|
523
|
+
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
524
|
+
Here, examples use msrun to pull multi-process distributed tasks across nodes with a single command
|
|
525
|
+
line instruction.
|
|
526
|
+
Please see the `Ascend tutorial
|
|
527
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
528
|
+
for more details.
|
|
529
|
+
|
|
530
|
+
This example should be run with multiple devices.
|
|
531
|
+
|
|
532
|
+
>>> # Firstly, preparing test_syncbn.py:
|
|
533
|
+
>>> import numpy as np
|
|
534
|
+
>>> import mindspore
|
|
535
|
+
>>> import mindspore.context as context
|
|
536
|
+
>>> from mindspore.mint.nn.layer import SyncBatchNorm
|
|
537
|
+
>>> from mindspore import Tensor
|
|
538
|
+
>>> from mindspore.communication import init, create_group, get_local_rank
|
|
539
|
+
>>> init()
|
|
540
|
+
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
|
541
|
+
>>> group = "0-1"
|
|
542
|
+
>>> rank_ids = [0, 1]
|
|
543
|
+
>>> create_group(group, rank_ids)
|
|
544
|
+
>>> sync_batch_norm = SyncBatchNorm(num_features=2, process_group=group, dtype=mindspore.float32)
|
|
545
|
+
>>> sync_batch_norm.set_train(False)
|
|
546
|
+
>>> input_x = Tensor(np.linspace(0, 5, 2*2*2*2), mindspore.float32).reshape(2, 2, 2, 2)
|
|
547
|
+
>>> output_data = sync_batch_norm(input_x)
|
|
548
|
+
>>> # Then, executing the command such as the following:
|
|
549
|
+
>>> # msrun --worker_num=2 --local_worker_num=2 --master_port=8975 --log_dir=msrun_log --join=True
|
|
550
|
+
>>> # --cluster_time_out=100 pytest -s -v test_syncbn.py
|
|
551
|
+
|
|
552
|
+
"""
|
|
553
|
+
def __init__(self,
|
|
554
|
+
num_features: int,
|
|
555
|
+
eps: float = 1e-5,
|
|
556
|
+
momentum: float = 0.1,
|
|
557
|
+
affine: bool = True,
|
|
558
|
+
track_running_stats: bool = True,
|
|
559
|
+
process_group: Optional[str] = None,
|
|
560
|
+
dtype=None):
|
|
561
|
+
super(SyncBatchNorm, self).__init__(
|
|
562
|
+
num_features, eps, momentum, affine, track_running_stats, dtype
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
self.process_group = process_group if process_group else GlobalComm.WORLD_COMM_GROUP
|
|
566
|
+
self.world_size = get_group_size(self.process_group)
|
|
567
|
+
self.sync_batch_norm = _SyncBatchNorm(
|
|
568
|
+
self.num_features, self.world_size, self.dtype)
|
|
569
|
+
|
|
570
|
+
def _check_input_dim(self, input):
|
|
571
|
+
if input.ndim < 2:
|
|
572
|
+
raise ValueError(
|
|
573
|
+
"expected at least 2D input (got {}D input)".format(input.ndim)
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
def _check_non_zero_input_channels(self, input):
|
|
577
|
+
if input.shape[1] == 0:
|
|
578
|
+
raise ValueError(
|
|
579
|
+
"SyncBatchNorm number of input channels should be non-zero"
|
|
580
|
+
)
|
|
581
|
+
|
|
582
|
+
def construct(self, input: Tensor) -> Tensor:
|
|
583
|
+
# currently only GPU input is supported
|
|
584
|
+
|
|
585
|
+
self._check_input_dim(input)
|
|
586
|
+
self._check_non_zero_input_channels(input)
|
|
587
|
+
|
|
588
|
+
# exponential_average_factor is set to self.momentum
|
|
589
|
+
# (when it is available) only so that it gets updated
|
|
590
|
+
# in ONNX graph when this node is exported to ONNX.
|
|
591
|
+
if self.momentum is None:
|
|
592
|
+
exponential_average_factor = 0.0
|
|
593
|
+
else:
|
|
594
|
+
exponential_average_factor = self.momentum
|
|
595
|
+
|
|
596
|
+
if self.training and self.track_running_stats:
|
|
597
|
+
self.num_batches_tracked += 1
|
|
598
|
+
if self.momentum is None: # use cumulative moving average
|
|
599
|
+
exponential_average_factor = 1.0 / float(self.num_batches_tracked.value())
|
|
600
|
+
else: # use exponential moving average
|
|
601
|
+
exponential_average_factor = self.momentum
|
|
602
|
+
|
|
603
|
+
r"""
|
|
604
|
+
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
|
|
605
|
+
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
|
|
606
|
+
"""
|
|
607
|
+
if self.training:
|
|
608
|
+
bn_training = True
|
|
609
|
+
else:
|
|
610
|
+
bn_training = (self.running_mean is None) and (
|
|
611
|
+
self.running_var is None)
|
|
612
|
+
|
|
613
|
+
r"""
|
|
614
|
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
|
615
|
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
|
616
|
+
used for normalization (i.e. in eval mode when buffers are not None).
|
|
617
|
+
"""
|
|
618
|
+
# If buffers are not to be tracked, ensure that they won't be updated
|
|
619
|
+
running_mean = (
|
|
620
|
+
self.running_mean if not self.training or self.track_running_stats else None
|
|
621
|
+
)
|
|
622
|
+
running_var = (
|
|
623
|
+
self.running_var if not self.training or self.track_running_stats else None
|
|
624
|
+
)
|
|
625
|
+
|
|
626
|
+
# Don't sync batchnorm stats in inference mode (model.eval()).
|
|
627
|
+
need_sync = (bn_training and self.training)
|
|
628
|
+
if need_sync:
|
|
629
|
+
need_sync = self.world_size > 1
|
|
630
|
+
|
|
631
|
+
# fallback to framework BN when synchronization is not necessary
|
|
632
|
+
if not need_sync:
|
|
633
|
+
if self.weight is None:
|
|
634
|
+
weight = Tensor(np.ones(self.num_features), dtype=self.dtype)
|
|
635
|
+
else:
|
|
636
|
+
weight = self.weight
|
|
637
|
+
if self.bias is None:
|
|
638
|
+
bias = Tensor(np.zeros(self.num_features), dtype=self.dtype)
|
|
639
|
+
else:
|
|
640
|
+
bias = self.bias
|
|
641
|
+
if running_mean is None or running_var is None:
|
|
642
|
+
raise ValueError(
|
|
643
|
+
"running mean or running var can\'t be none for batch_norm.")
|
|
644
|
+
return batch_norm(input,
|
|
645
|
+
running_mean,
|
|
646
|
+
running_var,
|
|
647
|
+
weight,
|
|
648
|
+
bias,
|
|
649
|
+
bn_training,
|
|
650
|
+
exponential_average_factor,
|
|
651
|
+
self.eps)
|
|
652
|
+
else:
|
|
653
|
+
output = self.sync_batch_norm(input,
|
|
654
|
+
self.weight,
|
|
655
|
+
self.bias,
|
|
656
|
+
running_mean,
|
|
657
|
+
running_var,
|
|
658
|
+
self.eps,
|
|
659
|
+
exponential_average_factor,
|
|
660
|
+
self.process_group,
|
|
661
|
+
self.world_size)
|
|
662
|
+
return output
|
|
663
|
+
|
|
664
|
+
|
|
471
665
|
__all__ = [
|
|
472
666
|
'GroupNorm',
|
|
473
667
|
'BatchNorm1d',
|
|
474
668
|
'BatchNorm2d',
|
|
475
669
|
'BatchNorm3d',
|
|
476
670
|
'LayerNorm',
|
|
671
|
+
'SyncBatchNorm',
|
|
477
672
|
]
|