mindspore 2.4.10__cp310-cp310-win_amd64.whl → 2.5.0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +8 -3
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +0 -5
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/compile_config.py +64 -0
- mindspore/_extends/parse/deprecated/__init__.py +0 -0
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +375 -0
- mindspore/_extends/parse/parser.py +23 -5
- mindspore/_extends/parse/standard_method.py +123 -27
- mindspore/_extends/pijit/pijit_func_white_list.py +1 -1
- mindspore/amp.py +7 -1
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/boost_cell_wrapper.py +136 -41
- mindspore/common/__init__.py +3 -1
- mindspore/common/_register_for_tensor.py +0 -1
- mindspore/common/_stub_tensor.py +25 -4
- mindspore/common/_tensor_cpp_method.py +17 -0
- mindspore/common/_tensor_docs.py +6132 -0
- mindspore/common/api.py +98 -21
- mindspore/common/dtype.py +34 -34
- mindspore/common/dump.py +2 -1
- mindspore/common/file_system.py +8 -3
- mindspore/common/generator.py +2 -0
- mindspore/common/hook_handle.py +3 -1
- mindspore/common/initializer.py +3 -4
- mindspore/common/lazy_inline.py +8 -2
- mindspore/common/mindir_util.py +10 -2
- mindspore/common/parameter.py +31 -15
- mindspore/common/tensor.py +713 -1337
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +5 -0
- mindspore/communication/comm_func.py +215 -173
- mindspore/communication/management.py +23 -20
- mindspore/context.py +285 -191
- mindspore/dataset/__init__.py +23 -19
- mindspore/dataset/callback/ds_callback.py +2 -1
- mindspore/dataset/core/config.py +84 -3
- mindspore/dataset/engine/cache_admin.py +3 -3
- mindspore/dataset/engine/cache_client.py +5 -4
- mindspore/dataset/engine/datasets.py +192 -149
- mindspore/dataset/engine/datasets_audio.py +14 -0
- mindspore/dataset/engine/datasets_standard_format.py +11 -11
- mindspore/dataset/engine/datasets_text.py +38 -1
- mindspore/dataset/engine/datasets_user_defined.py +100 -66
- mindspore/dataset/engine/datasets_vision.py +81 -8
- mindspore/dataset/engine/iterators.py +281 -63
- mindspore/dataset/engine/obs/util.py +8 -0
- mindspore/dataset/engine/queue.py +40 -0
- mindspore/dataset/engine/samplers.py +26 -2
- mindspore/dataset/engine/serializer_deserializer.py +1 -1
- mindspore/dataset/engine/validators.py +43 -11
- mindspore/dataset/transforms/py_transforms_util.py +17 -0
- mindspore/dataset/transforms/transforms.py +29 -12
- mindspore/dataset/vision/validators.py +1 -2
- mindspore/device_context/__init__.py +21 -0
- mindspore/device_context/ascend/__init__.py +25 -0
- mindspore/device_context/ascend/device.py +72 -0
- mindspore/device_context/ascend/op_debug.py +94 -0
- mindspore/device_context/ascend/op_precision.py +193 -0
- mindspore/device_context/ascend/op_tuning.py +127 -0
- mindspore/device_context/cpu/__init__.py +25 -0
- mindspore/device_context/cpu/device.py +62 -0
- mindspore/device_context/cpu/op_tuning.py +43 -0
- mindspore/device_context/gpu/__init__.py +21 -0
- mindspore/device_context/gpu/device.py +70 -0
- mindspore/device_context/gpu/op_precision.py +67 -0
- mindspore/device_context/gpu/op_tuning.py +175 -0
- mindspore/device_manager.py +134 -0
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/__init__.py +1 -0
- mindspore/experimental/llm_boost/ascend_native/__init__.py +22 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +211 -0
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +52 -0
- mindspore/experimental/llm_boost/atb/boost_base.py +2 -3
- mindspore/experimental/llm_boost/atb/llama_boost.py +6 -1
- mindspore/experimental/llm_boost/register.py +1 -0
- mindspore/experimental/optim/adadelta.py +26 -22
- mindspore/experimental/optim/adam.py +3 -0
- mindspore/experimental/optim/lr_scheduler.py +33 -24
- mindspore/experimental/optim/radam.py +33 -30
- mindspore/hal/device.py +28 -0
- mindspore/hal/event.py +17 -0
- mindspore/hal/memory.py +94 -3
- mindspore/hal/stream.py +91 -6
- mindspore/include/api/context.h +0 -1
- mindspore/jpeg62.dll +0 -0
- mindspore/log.py +12 -0
- mindspore/mindrecord/__init__.py +1 -1
- mindspore/mindrecord/config.py +17 -316
- mindspore/mindrecord/filereader.py +1 -9
- mindspore/mindrecord/filewriter.py +5 -15
- mindspore/mindrecord/mindpage.py +1 -9
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +824 -218
- mindspore/mint/distributed/__init__.py +66 -4
- mindspore/mint/distributed/distributed.py +2594 -44
- mindspore/mint/linalg/__init__.py +6 -0
- mindspore/mint/nn/__init__.py +473 -14
- mindspore/mint/nn/functional.py +486 -11
- mindspore/mint/nn/layer/__init__.py +17 -4
- mindspore/mint/nn/layer/_functions.py +330 -0
- mindspore/mint/nn/layer/activation.py +169 -1
- mindspore/mint/nn/layer/basic.py +123 -0
- mindspore/mint/nn/layer/conv.py +727 -0
- mindspore/mint/nn/layer/normalization.py +215 -19
- mindspore/mint/nn/layer/padding.py +797 -0
- mindspore/mint/nn/layer/pooling.py +170 -0
- mindspore/mint/optim/__init__.py +2 -1
- mindspore/mint/optim/adam.py +223 -0
- mindspore/mint/optim/adamw.py +26 -19
- mindspore/mint/special/__init__.py +2 -1
- mindspore/multiprocessing/__init__.py +5 -0
- mindspore/nn/cell.py +126 -19
- mindspore/nn/dynamic_lr.py +2 -1
- mindspore/nn/layer/activation.py +6 -6
- mindspore/nn/layer/basic.py +35 -25
- mindspore/nn/layer/channel_shuffle.py +3 -3
- mindspore/nn/layer/embedding.py +3 -3
- mindspore/nn/layer/normalization.py +8 -7
- mindspore/nn/layer/padding.py +4 -3
- mindspore/nn/layer/pooling.py +47 -13
- mindspore/nn/layer/rnn_cells.py +1 -1
- mindspore/nn/layer/rnns.py +2 -1
- mindspore/nn/layer/timedistributed.py +5 -5
- mindspore/nn/layer/transformer.py +48 -26
- mindspore/nn/learning_rate_schedule.py +5 -3
- mindspore/nn/loss/loss.py +31 -36
- mindspore/nn/optim/ada_grad.py +1 -0
- mindspore/nn/optim/adadelta.py +2 -2
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lars.py +1 -4
- mindspore/nn/optim/optimizer.py +1 -1
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/thor.py +2 -1
- mindspore/nn/utils/init.py +13 -11
- mindspore/nn/wrap/cell_wrapper.py +4 -6
- mindspore/nn/wrap/loss_scale.py +3 -4
- mindspore/numpy/array_creations.py +60 -62
- mindspore/numpy/array_ops.py +148 -143
- mindspore/numpy/logic_ops.py +41 -42
- mindspore/numpy/math_ops.py +361 -359
- mindspore/numpy/utils.py +16 -16
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +2 -1
- mindspore/ops/_grad_experimental/grad_comm_ops.py +94 -13
- mindspore/ops/_grad_experimental/grad_debug_ops.py +6 -1
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_grad_experimental/grad_math_ops.py +2 -1
- mindspore/ops/_op_impl/cpu/__init__.py +1 -0
- mindspore/ops/_op_impl/cpu/raise_op.py +28 -0
- mindspore/ops/_vmap/vmap_array_ops.py +20 -19
- mindspore/ops/_vmap/vmap_base.py +0 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +19 -13
- mindspore/ops/_vmap/vmap_math_ops.py +11 -9
- mindspore/ops/_vmap/vmap_nn_ops.py +20 -34
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +149 -12
- mindspore/ops/auto_generate/gen_arg_handler.py +0 -61
- mindspore/ops/auto_generate/gen_extend_func.py +554 -60
- mindspore/ops/auto_generate/gen_ops_def.py +1621 -115
- mindspore/ops/auto_generate/gen_ops_prim.py +8024 -3409
- mindspore/ops/auto_generate/pyboost_inner_prim.py +183 -79
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +229 -30
- mindspore/ops/composite/multitype_ops/pow_impl.py +0 -29
- mindspore/ops/function/__init__.py +12 -0
- mindspore/ops/function/array_func.py +561 -159
- mindspore/ops/function/clip_func.py +64 -0
- mindspore/ops/function/debug_func.py +28 -20
- mindspore/ops/function/image_func.py +1 -1
- mindspore/ops/function/linalg_func.py +5 -4
- mindspore/ops/function/math_func.py +1659 -290
- mindspore/ops/function/nn_func.py +988 -317
- mindspore/ops/function/parameter_func.py +3 -56
- mindspore/ops/function/random_func.py +243 -33
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/functional.py +18 -5
- mindspore/ops/functional_overload.py +897 -0
- mindspore/ops/operations/__init__.py +3 -2
- mindspore/ops/operations/_embedding_cache_ops.py +4 -4
- mindspore/ops/operations/_grad_ops.py +2 -34
- mindspore/ops/operations/_infer_ops.py +2 -1
- mindspore/ops/operations/_inner_ops.py +38 -8
- mindspore/ops/operations/array_ops.py +45 -303
- mindspore/ops/operations/comm_ops.py +19 -16
- mindspore/ops/operations/custom_ops.py +11 -55
- mindspore/ops/operations/debug_ops.py +42 -47
- mindspore/ops/operations/inner_ops.py +6 -4
- mindspore/ops/operations/linalg_ops.py +3 -2
- mindspore/ops/operations/manually_defined/ops_def.py +185 -104
- mindspore/ops/operations/math_ops.py +11 -216
- mindspore/ops/operations/nn_ops.py +146 -308
- mindspore/ops/primitive.py +23 -21
- mindspore/ops/tensor_method.py +1669 -0
- mindspore/ops_generate/aclnn_kernel_register_auto_cc_generator.py +110 -0
- mindspore/ops_generate/add_tensor_docs_generator.py +54 -0
- mindspore/ops_generate/arg_handler.py +0 -61
- mindspore/ops_generate/auto_grad_impl_cc_generator.py +135 -0
- mindspore/ops_generate/auto_grad_reg_cc_generator.py +93 -0
- mindspore/ops_generate/base_generator.py +11 -0
- mindspore/ops_generate/cpp_create_prim_instance_helper_generator.py +108 -0
- mindspore/ops_generate/functional_map_cpp_generator.py +491 -0
- mindspore/ops_generate/functional_overload_py_generator.py +110 -0
- mindspore/ops_generate/functions_cc_generator.py +233 -0
- mindspore/ops_generate/gen_aclnn_implement.py +110 -114
- mindspore/ops_generate/gen_constants.py +157 -3
- mindspore/ops_generate/gen_ops.py +245 -990
- mindspore/ops_generate/gen_pyboost_func.py +97 -998
- mindspore/ops_generate/gen_utils.py +119 -33
- mindspore/ops_generate/lite_ops_cpp_generator.py +155 -0
- mindspore/ops_generate/op_api_proto.py +206 -0
- mindspore/ops_generate/op_def_py_generator.py +131 -0
- mindspore/ops_generate/op_prim_py_generator.py +480 -0
- mindspore/ops_generate/op_proto.py +373 -108
- mindspore/ops_generate/op_template_parser.py +436 -0
- mindspore/ops_generate/ops_def_cc_generator.py +288 -0
- mindspore/ops_generate/ops_def_h_generator.py +74 -0
- mindspore/ops_generate/ops_name_h_generator.py +68 -0
- mindspore/ops_generate/ops_primitive_h_generator.py +81 -0
- mindspore/ops_generate/pyboost_functions_cpp_generator.py +370 -0
- mindspore/ops_generate/pyboost_functions_h_generator.py +68 -0
- mindspore/ops_generate/pyboost_functions_py_generator.py +148 -0
- mindspore/ops_generate/pyboost_grad_function_cpp_generator.py +154 -0
- mindspore/ops_generate/pyboost_inner_prim_generator.py +131 -0
- mindspore/ops_generate/pyboost_native_grad_functions_generator.py +268 -0
- mindspore/ops_generate/pyboost_op_cpp_code_generator.py +851 -0
- mindspore/ops_generate/pyboost_overload_functions_cpp_generator.py +344 -0
- mindspore/ops_generate/pyboost_utils.py +92 -33
- mindspore/ops_generate/template.py +294 -44
- mindspore/ops_generate/tensor_func_reg_cpp_generator.py +422 -0
- mindspore/parallel/__init__.py +3 -3
- mindspore/parallel/_auto_parallel_context.py +24 -33
- mindspore/parallel/_parallel_serialization.py +13 -2
- mindspore/parallel/_utils.py +4 -1
- mindspore/parallel/algo_parameter_config.py +1 -1
- mindspore/parallel/checkpoint_transform.py +44 -0
- mindspore/parallel/cluster/process_entity/_api.py +131 -37
- mindspore/parallel/cluster/process_entity/_utils.py +41 -6
- mindspore/parallel/cluster/run.py +20 -3
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +3 -0
- mindspore/parallel/transform_safetensors.py +119 -253
- mindspore/profiler/__init__.py +17 -4
- mindspore/profiler/analysis/__init__.py +0 -0
- mindspore/profiler/analysis/parser/__init__.py +0 -0
- mindspore/profiler/analysis/parser/ascend_cann_parser.py +166 -0
- mindspore/profiler/analysis/parser/base_parser.py +158 -0
- mindspore/profiler/analysis/parser/framework_cann_relation_parser.py +45 -0
- mindspore/profiler/analysis/parser/ms_framework_parser.py +142 -0
- mindspore/profiler/analysis/parser/ms_minddata_parser.py +145 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/ascend_timeline_assembler.py +261 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/base_timeline_assembler.py +40 -0
- mindspore/profiler/analysis/parser/timeline_assembly_factory/trace_view_container.py +84 -0
- mindspore/profiler/analysis/parser/timeline_creator/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_creator/base_timeline_creator.py +44 -0
- mindspore/profiler/analysis/parser/timeline_creator/cpu_op_timeline_creator.py +90 -0
- mindspore/profiler/analysis/parser/timeline_creator/fwk_timeline_creator.py +76 -0
- mindspore/profiler/analysis/parser/timeline_creator/msprof_timeline_creator.py +103 -0
- mindspore/profiler/analysis/parser/timeline_creator/scope_layer_timeline_creator.py +134 -0
- mindspore/profiler/analysis/parser/timeline_event/__init__.py +0 -0
- mindspore/profiler/analysis/parser/timeline_event/base_event.py +233 -0
- mindspore/profiler/analysis/parser/timeline_event/cpu_op_event.py +47 -0
- mindspore/profiler/analysis/parser/timeline_event/flow_event.py +36 -0
- mindspore/profiler/analysis/parser/timeline_event/fwk_event.py +260 -0
- mindspore/profiler/analysis/parser/timeline_event/msprof_event.py +73 -0
- mindspore/profiler/analysis/parser/timeline_event/scope_layer_event.py +53 -0
- mindspore/profiler/analysis/parser/timeline_event/timeline_event_pool.py +146 -0
- mindspore/profiler/analysis/task_manager.py +131 -0
- mindspore/profiler/analysis/time_converter.py +84 -0
- mindspore/profiler/analysis/viewer/__init__.py +0 -0
- mindspore/profiler/analysis/viewer/ascend_communication_viewer.py +333 -0
- mindspore/profiler/analysis/viewer/ascend_integrate_viewer.py +87 -0
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +252 -0
- mindspore/profiler/analysis/viewer/ascend_memory_viewer.py +313 -0
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +322 -0
- mindspore/profiler/analysis/viewer/ascend_step_trace_time_viewer.py +265 -0
- mindspore/profiler/analysis/viewer/ascend_timeline_viewer.py +58 -0
- mindspore/profiler/analysis/viewer/base_viewer.py +26 -0
- mindspore/profiler/analysis/viewer/ms_dataset_viewer.py +97 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +581 -0
- mindspore/profiler/analysis/work_flow.py +73 -0
- mindspore/profiler/common/ascend_msprof_exporter.py +138 -0
- mindspore/profiler/common/command_executor.py +90 -0
- mindspore/profiler/common/constant.py +174 -3
- mindspore/profiler/common/file_manager.py +208 -0
- mindspore/profiler/common/log.py +130 -0
- mindspore/profiler/common/msprof_cmd_tool.py +202 -0
- mindspore/profiler/common/path_manager.py +371 -0
- mindspore/profiler/common/process_bar.py +168 -0
- mindspore/profiler/common/process_pool.py +9 -3
- mindspore/profiler/common/profiler_context.py +476 -0
- mindspore/profiler/common/profiler_info.py +304 -0
- mindspore/profiler/common/profiler_output_path.py +284 -0
- mindspore/profiler/common/profiler_parameters.py +210 -0
- mindspore/profiler/common/profiler_path_manager.py +120 -0
- mindspore/profiler/common/record_function.py +76 -0
- mindspore/profiler/common/tlv_decoder.py +76 -0
- mindspore/profiler/common/util.py +75 -2
- mindspore/profiler/dynamic_profiler.py +270 -37
- mindspore/profiler/envprofiler.py +138 -0
- mindspore/profiler/mstx.py +199 -0
- mindspore/profiler/platform/__init__.py +21 -0
- mindspore/profiler/platform/base_profiler.py +40 -0
- mindspore/profiler/platform/cpu_profiler.py +124 -0
- mindspore/profiler/platform/gpu_profiler.py +74 -0
- mindspore/profiler/platform/npu_profiler.py +309 -0
- mindspore/profiler/profiler.py +580 -93
- mindspore/profiler/profiler_action_controller.py +187 -0
- mindspore/profiler/profiler_interface.py +114 -0
- mindspore/profiler/schedule.py +208 -0
- mindspore/rewrite/api/symbol_tree.py +1 -2
- mindspore/run_check/_check_version.py +2 -6
- mindspore/runtime/__init__.py +37 -0
- mindspore/runtime/device.py +27 -0
- mindspore/runtime/event.py +209 -0
- mindspore/runtime/executor.py +148 -0
- mindspore/runtime/memory.py +392 -0
- mindspore/runtime/stream.py +460 -0
- mindspore/runtime/thread_bind_core.py +401 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/__init__.py +2 -2
- mindspore/train/_utils.py +53 -18
- mindspore/train/amp.py +8 -4
- mindspore/train/callback/_checkpoint.py +32 -18
- mindspore/train/callback/_early_stop.py +1 -1
- mindspore/train/callback/_flops_collector.py +105 -69
- mindspore/train/callback/_history.py +1 -1
- mindspore/train/callback/_summary_collector.py +44 -6
- mindspore/train/callback/_tft_register.py +31 -10
- mindspore/train/dataset_helper.py +11 -11
- mindspore/train/metrics/precision.py +4 -5
- mindspore/train/mind_ir_pb2.py +167 -46
- mindspore/train/model.py +13 -15
- mindspore/train/serialization.py +462 -76
- mindspore/train/summary/summary_record.py +1 -2
- mindspore/train/train_thor/model_thor.py +1 -1
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +4 -2
- mindspore/utils/dryrun.py +138 -0
- mindspore/utils/runtime_execution_order_check.py +550 -0
- mindspore/version.py +1 -1
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/METADATA +2 -3
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/RECORD +362 -238
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/entry_points.txt +1 -1
- mindspore/common/_tensor_overload.py +0 -139
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/profiler/envprofiling.py +0 -254
- mindspore/profiler/profiling.py +0 -1926
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/WHEEL +0 -0
- {mindspore-2.4.10.dist-info → mindspore-2.5.0.dist-info}/top_level.txt +0 -0
|
@@ -28,7 +28,11 @@ from mindspore import _checkparam as validator
|
|
|
28
28
|
from mindspore.common import dtype as mstype
|
|
29
29
|
from mindspore.nn.cell import Cell
|
|
30
30
|
from mindspore.nn.layer.normalization import LayerNormExt as LayerNorm
|
|
31
|
-
from mindspore.
|
|
31
|
+
from mindspore.communication import get_group_size
|
|
32
|
+
from mindspore.communication._comm_helper import GlobalComm
|
|
33
|
+
from mindspore.ops.function import batch_norm
|
|
34
|
+
|
|
35
|
+
from ._functions import _SyncBatchNorm
|
|
32
36
|
|
|
33
37
|
|
|
34
38
|
class _NormBase(Cell):
|
|
@@ -43,6 +47,7 @@ class _NormBase(Cell):
|
|
|
43
47
|
dtype=None
|
|
44
48
|
) -> None:
|
|
45
49
|
super(_NormBase, self).__init__()
|
|
50
|
+
self.set_train()
|
|
46
51
|
self.shape = ops.Shape()
|
|
47
52
|
self.num_features = num_features
|
|
48
53
|
self.eps = eps
|
|
@@ -55,8 +60,6 @@ class _NormBase(Cell):
|
|
|
55
60
|
Tensor(np.empty(num_features), dtype=self.dtype), name="weight")
|
|
56
61
|
self.bias = Parameter(
|
|
57
62
|
Tensor(np.empty(num_features), dtype=self.dtype), name="bias")
|
|
58
|
-
self.weight: Optional[Parameter]
|
|
59
|
-
self.bias: Optional[Parameter]
|
|
60
63
|
else:
|
|
61
64
|
self.weight = None
|
|
62
65
|
self.bias = None
|
|
@@ -65,11 +68,8 @@ class _NormBase(Cell):
|
|
|
65
68
|
requires_grad=False, name="running_mean")
|
|
66
69
|
self.running_var = Parameter(Tensor(np.ones(num_features), dtype=self.dtype),
|
|
67
70
|
requires_grad=False, name="running_var")
|
|
68
|
-
self.running_mean: Optional[Tensor]
|
|
69
|
-
self.running_var: Optional[Tensor]
|
|
70
71
|
self.num_batches_tracked = Parameter(Tensor(0, dtype=ms.float32),
|
|
71
72
|
requires_grad=False, name="num_batches_tracked")
|
|
72
|
-
self.num_batches_tracked: Optional[Tensor]
|
|
73
73
|
else:
|
|
74
74
|
self.running_mean = None
|
|
75
75
|
self.running_var = None
|
|
@@ -122,6 +122,7 @@ class _BatchNorm(_NormBase):
|
|
|
122
122
|
dtype)
|
|
123
123
|
self.training = True
|
|
124
124
|
|
|
125
|
+
|
|
125
126
|
def _check_input_dim(self, input):
|
|
126
127
|
raise NotImplementedError
|
|
127
128
|
|
|
@@ -206,7 +207,7 @@ class BatchNorm1d(_BatchNorm):
|
|
|
206
207
|
Tensor, has the same type and shape as `input`.
|
|
207
208
|
|
|
208
209
|
Raises:
|
|
209
|
-
TypeError: If `num_features` is not
|
|
210
|
+
TypeError: If `num_features` is not an int number.
|
|
210
211
|
TypeError: If `eps` is not a float.
|
|
211
212
|
ValueError: If `num_features` is less than 1.
|
|
212
213
|
|
|
@@ -241,7 +242,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
241
242
|
|
|
242
243
|
.. math::
|
|
243
244
|
|
|
244
|
-
y = \frac{x -
|
|
245
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
245
246
|
|
|
246
247
|
The mean and standard-deviation are calculated per-dimension over
|
|
247
248
|
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
|
|
@@ -273,7 +274,7 @@ class BatchNorm2d(_BatchNorm):
|
|
|
273
274
|
Tensor, has the same type and shape as `input`.
|
|
274
275
|
|
|
275
276
|
Raises:
|
|
276
|
-
TypeError: If `num_features` is not
|
|
277
|
+
TypeError: If `num_features` is not an int number.
|
|
277
278
|
TypeError: If `eps` is not a float.
|
|
278
279
|
ValueError: If `num_features` is less than 1.
|
|
279
280
|
|
|
@@ -311,7 +312,7 @@ class BatchNorm3d(_BatchNorm):
|
|
|
311
312
|
|
|
312
313
|
.. math::
|
|
313
314
|
|
|
314
|
-
y = \frac{x -
|
|
315
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
315
316
|
|
|
316
317
|
The mean and standard-deviation are calculated per-dimension over
|
|
317
318
|
the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
|
|
@@ -343,7 +344,7 @@ class BatchNorm3d(_BatchNorm):
|
|
|
343
344
|
Tensor, has the same type and shape as `input`.
|
|
344
345
|
|
|
345
346
|
Raises:
|
|
346
|
-
TypeError: If `num_features` is not
|
|
347
|
+
TypeError: If `num_features` is not an int number.
|
|
347
348
|
TypeError: If `eps` is not a float.
|
|
348
349
|
ValueError: If `num_features` is less than 1.
|
|
349
350
|
|
|
@@ -402,7 +403,7 @@ class GroupNorm(Cell):
|
|
|
402
403
|
additional dimensions.
|
|
403
404
|
|
|
404
405
|
Outputs:
|
|
405
|
-
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `
|
|
406
|
+
Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `input`.
|
|
406
407
|
|
|
407
408
|
Raises:
|
|
408
409
|
TypeError: If `num_groups` or `num_channels` is not an int.
|
|
@@ -435,8 +436,8 @@ class GroupNorm(Cell):
|
|
|
435
436
|
"""Initialize GroupNorm."""
|
|
436
437
|
super(GroupNorm, self).__init__()
|
|
437
438
|
ms_dtype = mstype.float32 if dtype is None else dtype
|
|
438
|
-
|
|
439
|
-
|
|
439
|
+
weight_init = 'ones'
|
|
440
|
+
bias_init = 'zeros'
|
|
440
441
|
|
|
441
442
|
self.num_groups = validator.check_positive_int(
|
|
442
443
|
num_groups, "num_groups", self.cls_name)
|
|
@@ -450,14 +451,14 @@ class GroupNorm(Cell):
|
|
|
450
451
|
self.affine = validator.check_bool(
|
|
451
452
|
affine, arg_name="affine", prim_name=self.cls_name)
|
|
452
453
|
|
|
453
|
-
self.
|
|
454
|
-
|
|
455
|
-
self.
|
|
456
|
-
|
|
454
|
+
self.weight = Parameter(initializer(
|
|
455
|
+
weight_init, self.num_channels, dtype=ms_dtype), name="weight", requires_grad=affine)
|
|
456
|
+
self.bias = Parameter(initializer(
|
|
457
|
+
bias_init, self.num_channels, dtype=ms_dtype), name="bias", requires_grad=affine)
|
|
457
458
|
|
|
458
459
|
def _cal_output(self, x):
|
|
459
460
|
"""calculate groupnorm output"""
|
|
460
|
-
return group_norm(x, self.num_groups, self.
|
|
461
|
+
return ops.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
|
|
461
462
|
|
|
462
463
|
def extend_repr(self):
|
|
463
464
|
return 'num_groups={}, num_channels={}, eps={}, affine={}'.format(
|
|
@@ -468,10 +469,205 @@ class GroupNorm(Cell):
|
|
|
468
469
|
return output
|
|
469
470
|
|
|
470
471
|
|
|
472
|
+
class SyncBatchNorm(_BatchNorm):
|
|
473
|
+
r"""
|
|
474
|
+
Sync Batch Normalization layer over a N-dimension input.
|
|
475
|
+
|
|
476
|
+
Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch
|
|
477
|
+
Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input
|
|
478
|
+
within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
|
|
479
|
+
Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
|
|
480
|
+
feature using a mini-batch of data and the learned parameters which can be described in the following formula.
|
|
481
|
+
|
|
482
|
+
.. math::
|
|
483
|
+
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
|
484
|
+
|
|
485
|
+
.. warning::
|
|
486
|
+
This is an experimental API that is subject to change or deletion.
|
|
487
|
+
|
|
488
|
+
Args:
|
|
489
|
+
num_features (int): `C` from an expected input of size :math:`(N, C, +)`.
|
|
490
|
+
eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: ``1e-5`` .
|
|
491
|
+
momentum (float): A floating hyperparameter of the momentum for the
|
|
492
|
+
running_mean and running_var computation. Default: ``0.1`` .
|
|
493
|
+
affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
|
|
494
|
+
Default: ``True`` .
|
|
495
|
+
track_running_stats (bool, optional): a boolean value that when set to ``True``, this
|
|
496
|
+
cell tracks the running mean and variance, and when set to ``False``,
|
|
497
|
+
this cell does not track such statistics. And this cell always uses batch statistics
|
|
498
|
+
in both training and eval modes. Default: ``True`` .
|
|
499
|
+
process_group (str, optional): synchronization of stats happen within each process group individually.
|
|
500
|
+
Default behavior is synchronization across the whole world. Default: ``None`` .
|
|
501
|
+
dtype (:class:`mindspore.dtype`, optional): Dtype of Parameters. Default: ``None`` .
|
|
502
|
+
|
|
503
|
+
Inputs:
|
|
504
|
+
- **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, +)`.
|
|
505
|
+
|
|
506
|
+
Outputs:
|
|
507
|
+
Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, +)`.
|
|
508
|
+
|
|
509
|
+
Raises:
|
|
510
|
+
TypeError: If `num_features` is not an int.
|
|
511
|
+
TypeError: If `eps` is not a float.
|
|
512
|
+
ValueError: If `num_features` is less than 1.
|
|
513
|
+
ValueError: If `momentum` is not in range [0, 1].
|
|
514
|
+
ValueError: If rank_id in `process_group` is not in range [0, rank_size).
|
|
515
|
+
|
|
516
|
+
Supported Platforms:
|
|
517
|
+
``Ascend``
|
|
518
|
+
|
|
519
|
+
Examples:
|
|
520
|
+
.. note::
|
|
521
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
522
|
+
|
|
523
|
+
For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
|
|
524
|
+
Here, examples use msrun to pull multi-process distributed tasks across nodes with a single command
|
|
525
|
+
line instruction.
|
|
526
|
+
Please see the `Ascend tutorial
|
|
527
|
+
<https://www.mindspore.cn/docs/en/master/model_train/parallel/msrun_launcher.html>`_
|
|
528
|
+
for more details.
|
|
529
|
+
|
|
530
|
+
This example should be run with multiple devices.
|
|
531
|
+
|
|
532
|
+
>>> # Firstly, preparing test_syncbn.py:
|
|
533
|
+
>>> import numpy as np
|
|
534
|
+
>>> import mindspore
|
|
535
|
+
>>> import mindspore.context as context
|
|
536
|
+
>>> from mindspore.mint.nn.layer import SyncBatchNorm
|
|
537
|
+
>>> from mindspore import Tensor
|
|
538
|
+
>>> from mindspore.communication import init, create_group, get_local_rank
|
|
539
|
+
>>> init()
|
|
540
|
+
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
|
541
|
+
>>> group = "0-1"
|
|
542
|
+
>>> rank_ids = [0, 1]
|
|
543
|
+
>>> create_group(group, rank_ids)
|
|
544
|
+
>>> sync_batch_norm = SyncBatchNorm(num_features=2, process_group=group, dtype=mindspore.float32)
|
|
545
|
+
>>> sync_batch_norm.set_train(False)
|
|
546
|
+
>>> input_x = Tensor(np.linspace(0, 5, 2*2*2*2), mindspore.float32).reshape(2, 2, 2, 2)
|
|
547
|
+
>>> output_data = sync_batch_norm(input_x)
|
|
548
|
+
>>> # Then, executing the command such as the following:
|
|
549
|
+
>>> # msrun --worker_num=2 --local_worker_num=2 --master_port=8975 --log_dir=msrun_log --join=True
|
|
550
|
+
>>> # --cluster_time_out=100 pytest -s -v test_syncbn.py
|
|
551
|
+
|
|
552
|
+
"""
|
|
553
|
+
def __init__(self,
|
|
554
|
+
num_features: int,
|
|
555
|
+
eps: float = 1e-5,
|
|
556
|
+
momentum: float = 0.1,
|
|
557
|
+
affine: bool = True,
|
|
558
|
+
track_running_stats: bool = True,
|
|
559
|
+
process_group: Optional[str] = None,
|
|
560
|
+
dtype=None):
|
|
561
|
+
super(SyncBatchNorm, self).__init__(
|
|
562
|
+
num_features, eps, momentum, affine, track_running_stats, dtype
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
self.process_group = process_group if process_group else GlobalComm.WORLD_COMM_GROUP
|
|
566
|
+
self.world_size = get_group_size(self.process_group)
|
|
567
|
+
self.sync_batch_norm = _SyncBatchNorm(
|
|
568
|
+
self.num_features, self.world_size, self.dtype)
|
|
569
|
+
|
|
570
|
+
def _check_input_dim(self, input):
|
|
571
|
+
if input.ndim < 2:
|
|
572
|
+
raise ValueError(
|
|
573
|
+
"expected at least 2D input (got {}D input)".format(input.ndim)
|
|
574
|
+
)
|
|
575
|
+
|
|
576
|
+
def _check_non_zero_input_channels(self, input):
|
|
577
|
+
if input.shape[1] == 0:
|
|
578
|
+
raise ValueError(
|
|
579
|
+
"SyncBatchNorm number of input channels should be non-zero"
|
|
580
|
+
)
|
|
581
|
+
|
|
582
|
+
def construct(self, input: Tensor) -> Tensor:
|
|
583
|
+
# currently only GPU input is supported
|
|
584
|
+
|
|
585
|
+
self._check_input_dim(input)
|
|
586
|
+
self._check_non_zero_input_channels(input)
|
|
587
|
+
|
|
588
|
+
# exponential_average_factor is set to self.momentum
|
|
589
|
+
# (when it is available) only so that it gets updated
|
|
590
|
+
# in ONNX graph when this node is exported to ONNX.
|
|
591
|
+
if self.momentum is None:
|
|
592
|
+
exponential_average_factor = 0.0
|
|
593
|
+
else:
|
|
594
|
+
exponential_average_factor = self.momentum
|
|
595
|
+
|
|
596
|
+
if self.training and self.track_running_stats:
|
|
597
|
+
one_tensor = Tensor(1, dtype=ms.float32)
|
|
598
|
+
ops.assign_add(self.num_batches_tracked, one_tensor)
|
|
599
|
+
if self.momentum is None: # use cumulative moving average
|
|
600
|
+
exponential_average_factor = 1.0 / self.num_batches_tracked.value()
|
|
601
|
+
else: # use exponential moving average
|
|
602
|
+
exponential_average_factor = self.momentum
|
|
603
|
+
|
|
604
|
+
r"""
|
|
605
|
+
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
|
|
606
|
+
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
|
|
607
|
+
"""
|
|
608
|
+
if self.training:
|
|
609
|
+
bn_training = True
|
|
610
|
+
else:
|
|
611
|
+
bn_training = (self.running_mean is None) and (
|
|
612
|
+
self.running_var is None)
|
|
613
|
+
|
|
614
|
+
r"""
|
|
615
|
+
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
|
|
616
|
+
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
|
|
617
|
+
used for normalization (i.e. in eval mode when buffers are not None).
|
|
618
|
+
"""
|
|
619
|
+
# If buffers are not to be tracked, ensure that they won't be updated
|
|
620
|
+
running_mean = (
|
|
621
|
+
self.running_mean if not self.training or self.track_running_stats else None
|
|
622
|
+
)
|
|
623
|
+
running_var = (
|
|
624
|
+
self.running_var if not self.training or self.track_running_stats else None
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
# Don't sync batchnorm stats in inference mode (model.eval()).
|
|
628
|
+
need_sync = (bn_training and self.training)
|
|
629
|
+
if need_sync:
|
|
630
|
+
need_sync = self.world_size > 1
|
|
631
|
+
|
|
632
|
+
# fallback to framework BN when synchronization is not necessary
|
|
633
|
+
if not need_sync:
|
|
634
|
+
if self.weight is None:
|
|
635
|
+
weight = Tensor(np.ones(self.num_features), dtype=self.dtype)
|
|
636
|
+
else:
|
|
637
|
+
weight = self.weight
|
|
638
|
+
if self.bias is None:
|
|
639
|
+
bias = Tensor(np.zeros(self.num_features), dtype=self.dtype)
|
|
640
|
+
else:
|
|
641
|
+
bias = self.bias
|
|
642
|
+
if running_mean is None or running_var is None:
|
|
643
|
+
raise ValueError(
|
|
644
|
+
"running mean or running var can\'t be none for batch_norm.")
|
|
645
|
+
return batch_norm(input,
|
|
646
|
+
running_mean,
|
|
647
|
+
running_var,
|
|
648
|
+
weight,
|
|
649
|
+
bias,
|
|
650
|
+
bn_training,
|
|
651
|
+
exponential_average_factor,
|
|
652
|
+
self.eps)
|
|
653
|
+
else:
|
|
654
|
+
output = self.sync_batch_norm(input,
|
|
655
|
+
self.weight,
|
|
656
|
+
self.bias,
|
|
657
|
+
running_mean,
|
|
658
|
+
running_var,
|
|
659
|
+
self.eps,
|
|
660
|
+
exponential_average_factor,
|
|
661
|
+
self.process_group,
|
|
662
|
+
self.world_size)
|
|
663
|
+
return output
|
|
664
|
+
|
|
665
|
+
|
|
471
666
|
__all__ = [
|
|
472
667
|
'GroupNorm',
|
|
473
668
|
'BatchNorm1d',
|
|
474
669
|
'BatchNorm2d',
|
|
475
670
|
'BatchNorm3d',
|
|
476
671
|
'LayerNorm',
|
|
672
|
+
'SyncBatchNorm',
|
|
477
673
|
]
|