mindspore 2.7.0__cp311-cp311-win_amd64.whl → 2.7.1__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic. Click here for more details.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp311-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp311-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp311-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/mint/__init__.py
CHANGED
|
@@ -486,6 +486,11 @@ from mindspore.ops.auto_generate.pyboost_inner_prim import squeeze_impl
|
|
|
486
486
|
from mindspore.ops.auto_generate.gen_ops_prim import equal_ext_op
|
|
487
487
|
|
|
488
488
|
|
|
489
|
+
# 1101
|
|
490
|
+
from mindspore.ops.functional_overload import real
|
|
491
|
+
# 1102
|
|
492
|
+
from mindspore.ops.functional_overload import imag
|
|
493
|
+
|
|
489
494
|
# 1023
|
|
490
495
|
from mindspore.ops.function.array_func import unbind_ext as unbind
|
|
491
496
|
|
|
@@ -609,6 +614,9 @@ def cat(tensors, dim=0):
|
|
|
609
614
|
|
|
610
615
|
(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)
|
|
611
616
|
|
|
617
|
+
.. warning::
|
|
618
|
+
Input tensor of inconsistent types are not supported in Graph Mode under Dynamic Shape.
|
|
619
|
+
|
|
612
620
|
Args:
|
|
613
621
|
tensors (Union[tuple, list]): A tuple or a list of input tensors.
|
|
614
622
|
Suppose there are two tensors in this tuple or list, namely t1 and t2.
|
|
@@ -620,7 +628,6 @@ def cat(tensors, dim=0):
|
|
|
620
628
|
|
|
621
629
|
Returns:
|
|
622
630
|
Tensor, the shape is :math:`(x_1, x_2, ..., \sum_{i=1}^Nx_{mi}, ..., x_R)`.
|
|
623
|
-
The data type is the same with `tensors`.
|
|
624
631
|
|
|
625
632
|
Raises:
|
|
626
633
|
TypeError: If `dim` is not an int.
|
|
@@ -658,7 +665,7 @@ def concat(tensors, dim=0):
|
|
|
658
665
|
Alias for :func:`mindspore.mint.cat`.
|
|
659
666
|
|
|
660
667
|
.. warning::
|
|
661
|
-
|
|
668
|
+
Input tensor of inconsistent types are not supported in Graph Mode under Dynamic Shape.
|
|
662
669
|
|
|
663
670
|
Supported Platforms:
|
|
664
671
|
``Ascend`` ``GPU`` ``CPU``
|
|
@@ -744,9 +751,6 @@ def equal(input, other):
|
|
|
744
751
|
Note:
|
|
745
752
|
`input` and `other` comply with the implicit type conversion rules to make the data types consistent.
|
|
746
753
|
|
|
747
|
-
.. warning::
|
|
748
|
-
This is an experimental API that is subject to change or deletion.
|
|
749
|
-
|
|
750
754
|
Args:
|
|
751
755
|
input (Tensor): The first input.
|
|
752
756
|
other (Tensor): The second input.
|
|
@@ -1323,8 +1327,6 @@ def swapaxes(input, axis0, axis1):
|
|
|
1323
1327
|
Alias for :func:`mindspore.mint.transpose` . The `input` corresponds to the `input` in the reference interface,
|
|
1324
1328
|
and the parameters `axis0` and `axis1` correspond to `dim0` and `dim1` in the reference interface respectively.
|
|
1325
1329
|
|
|
1326
|
-
For more details, see :func:`mindspore.mint.transpose` .
|
|
1327
|
-
|
|
1328
1330
|
.. warning::
|
|
1329
1331
|
This is an experimental API that is subject to change or deletion.
|
|
1330
1332
|
|
|
@@ -1438,8 +1440,6 @@ def fix(input):
|
|
|
1438
1440
|
"""
|
|
1439
1441
|
Alias for :func:`mindspore.mint.trunc` .
|
|
1440
1442
|
|
|
1441
|
-
For more details, see :func:`mindspore.mint.trunc` .
|
|
1442
|
-
|
|
1443
1443
|
Supported Platforms:
|
|
1444
1444
|
``Ascend``
|
|
1445
1445
|
"""
|
|
@@ -1535,7 +1535,7 @@ def cdist(x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary'):
|
|
|
1535
1535
|
x2 (Tensor): Input tensor of shape :math:`(B, R, M)`, has the same dtype as `x1`.
|
|
1536
1536
|
p (float, optional): P value for the p-norm distance to calculate between each
|
|
1537
1537
|
vector pair, P >= 0. Default: ``2.0`` .
|
|
1538
|
-
compute_mode (
|
|
1538
|
+
compute_mode (str, optional): Specify the cumpute mode. Setting this parameter currently has no effect.
|
|
1539
1539
|
Default: ``'use_mm_for_euclid_dist_if_necessary'`` .
|
|
1540
1540
|
|
|
1541
1541
|
Returns:
|
|
@@ -2061,6 +2061,11 @@ __all__ = [
|
|
|
2061
2061
|
|
|
2062
2062
|
# 1100
|
|
2063
2063
|
'diff',
|
|
2064
|
+
|
|
2065
|
+
# 1101
|
|
2066
|
+
'real',
|
|
2067
|
+
# 1102
|
|
2068
|
+
'imag',
|
|
2064
2069
|
]
|
|
2065
2070
|
|
|
2066
2071
|
__all__.extend(functional.__all__)
|
|
@@ -18,13 +18,16 @@ import hashlib
|
|
|
18
18
|
import builtins
|
|
19
19
|
import io
|
|
20
20
|
import pickle
|
|
21
|
+
from datetime import timedelta
|
|
21
22
|
import numpy as np
|
|
22
23
|
from mindspore import log as logger
|
|
23
24
|
from mindspore.common import dtype as mstype
|
|
25
|
+
from mindspore._checkparam import args_type_check
|
|
24
26
|
from mindspore.ops import ReduceOp, cat
|
|
25
27
|
from mindspore.common.tensor import Tensor
|
|
26
28
|
from mindspore._c_expression import TensorPy as Tensor_
|
|
27
29
|
from mindspore.ops.primitive import _primexpr
|
|
30
|
+
from mindspore.common.api import _pynative_executor
|
|
28
31
|
from mindspore.communication._comm_helper import (
|
|
29
32
|
_destroy_group_helper,
|
|
30
33
|
_get_rank_helper,
|
|
@@ -33,10 +36,11 @@ from mindspore.communication._comm_helper import (
|
|
|
33
36
|
_get_group_ranks,
|
|
34
37
|
_is_available,
|
|
35
38
|
_is_initialized,
|
|
39
|
+
_ExistingGroup,
|
|
36
40
|
)
|
|
41
|
+
from mindspore.communication.management import _init_without_sched
|
|
37
42
|
from mindspore.communication import (
|
|
38
43
|
init,
|
|
39
|
-
release,
|
|
40
44
|
get_group_size,
|
|
41
45
|
get_world_rank_from_group_rank,
|
|
42
46
|
create_group,
|
|
@@ -72,7 +76,7 @@ from mindspore.ops.auto_generate.gen_ops_prim import (
|
|
|
72
76
|
dist_comm_barrier_op,
|
|
73
77
|
dist_comm_batch_isend_irecv_op,
|
|
74
78
|
)
|
|
75
|
-
from mindspore._c_expression import TCPStoreClient, GroupOptions
|
|
79
|
+
from mindspore._c_expression import TCPStoreClient, GroupOptions, _finalize_collective
|
|
76
80
|
|
|
77
81
|
_pickler = pickle.Pickler
|
|
78
82
|
_unpickler = pickle.Unpickler
|
|
@@ -146,28 +150,26 @@ class TCPStore:
|
|
|
146
150
|
|
|
147
151
|
Note:
|
|
148
152
|
- The function is implemented by CPU and does not involve any hardware operations related to Ascend.
|
|
149
|
-
- Currently, all parameters provided by the TCPStore class constructor are not supported
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
- The current
|
|
153
|
+
- Currently, all parameters provided by the TCPStore class constructor are not supported
|
|
154
|
+
except for `host_name`, `port`, `world_size`, `is_master`, `timeout` and `wait_for_workers`,
|
|
155
|
+
which are reserved parameters and invalid settings.
|
|
156
|
+
- The current TCPStore function is limited and only supports scenarios where the key is
|
|
153
157
|
less than 4k and the value is less than 1G. Complex scenarios are to be supported.
|
|
154
|
-
- The timeout interval for message sending and receiving in the TcpStore function is controlled by
|
|
155
|
-
the `MS_RECEIVE_MSG_TIMEOUT` environment variable, in seconds, with a default value of ``15``.
|
|
156
|
-
If a timeout occurs, the user needs to increase the configuration value.
|
|
157
158
|
|
|
158
159
|
Args:
|
|
159
|
-
host_name (str
|
|
160
|
-
|
|
161
|
-
port (int
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
is_master (bool,
|
|
160
|
+
host_name (str): The hostname or IP Address the server store should run on.
|
|
161
|
+
Currently only supports user input IP addresses.
|
|
162
|
+
port (int): The port on which the server store should listen for incoming requests.
|
|
163
|
+
world_size (int, optional): The total number of store users (number of clients + 1 for the server).
|
|
164
|
+
Default is ``None``, indicates a non-fixed number of store users. This parameter is
|
|
165
|
+
only valid for the server.
|
|
166
|
+
is_master (bool, optional): True when initializing the server store and False for client stores.
|
|
166
167
|
Default is ``False``.
|
|
167
|
-
timeout (timedelta,
|
|
168
|
-
|
|
169
|
-
wait_for_workers (bool,
|
|
170
|
-
store. This is only applicable when `world_size` is a fixed value. Default is ``True``.
|
|
168
|
+
timeout (timedelta, optional): Timeout used by the store during initialization. Default is
|
|
169
|
+
``timedelta(seconds=300)``.
|
|
170
|
+
wait_for_workers (bool, optional): Whether to wait for all the workers to connect with the server
|
|
171
|
+
store. This is only applicable when `world_size` is a fixed value. Default is ``True``. This
|
|
172
|
+
parameter is only valid for the server.
|
|
171
173
|
multi_tenant (bool, invalid, optional): If ``True``, all ``TCPStore`` instances in the current process with
|
|
172
174
|
the same host/port will use the same underlying ``TCPServer``. Default is ``False``.
|
|
173
175
|
master_listen_fd (int, invalid, optional): If specified, the underlying ``TCPServer`` will listen on this file
|
|
@@ -193,12 +195,106 @@ class TCPStore:
|
|
|
193
195
|
for more details.
|
|
194
196
|
|
|
195
197
|
>>> from mindspore.mint.distributed import TCPStore
|
|
196
|
-
>>> store = TCPStore()
|
|
198
|
+
>>> store = TCPStore("127.0.0.1", 1234)
|
|
197
199
|
"""
|
|
198
200
|
|
|
199
|
-
def __init__(self, host_name
|
|
201
|
+
def __init__(self, host_name, port, world_size=None, is_master=False, timeout=timedelta(seconds=300),
|
|
200
202
|
wait_for_workers=True, multi_tenant=False, master_listen_fd=None, use_libuv=True):
|
|
201
|
-
|
|
203
|
+
if not isinstance(host_name, str):
|
|
204
|
+
raise TypeError(
|
|
205
|
+
"For 'TCPStore', the argument 'host_name' must be type of string, "
|
|
206
|
+
"but got 'host_name' type : {}.".format(type(host_name))
|
|
207
|
+
)
|
|
208
|
+
if not isinstance(port, int):
|
|
209
|
+
raise TypeError(
|
|
210
|
+
"For 'TCPStore', the argument 'port' must be type of int, "
|
|
211
|
+
"but got 'port' type : {}.".format(type(port))
|
|
212
|
+
)
|
|
213
|
+
if not isinstance(is_master, bool):
|
|
214
|
+
raise TypeError(
|
|
215
|
+
"For 'TCPStore', the argument 'is_master' must be type of bool, "
|
|
216
|
+
"but got 'is_master' type : {}.".format(type(is_master))
|
|
217
|
+
)
|
|
218
|
+
if not isinstance(timeout, timedelta):
|
|
219
|
+
raise TypeError(
|
|
220
|
+
"For 'TCPStore', the argument 'timeout' must be type of timedelta, "
|
|
221
|
+
"but got 'timeout' type : {}.".format(type(timeout))
|
|
222
|
+
)
|
|
223
|
+
if not isinstance(wait_for_workers, bool):
|
|
224
|
+
raise TypeError(
|
|
225
|
+
"For 'TCPStore', the argument 'wait_for_workers' must be type of bool, "
|
|
226
|
+
"but got 'wait_for_workers' type : {}.".format(type(wait_for_workers))
|
|
227
|
+
)
|
|
228
|
+
if world_size is None:
|
|
229
|
+
world_size = 1
|
|
230
|
+
if not isinstance(world_size, int):
|
|
231
|
+
raise TypeError(
|
|
232
|
+
"For 'TCPStore', the argument 'world_size' must be type of int, "
|
|
233
|
+
"but got 'world_size' type : {}.".format(type(world_size))
|
|
234
|
+
)
|
|
235
|
+
if port < 0 or port > 65535:
|
|
236
|
+
raise ValueError(
|
|
237
|
+
"For 'TCPStore', the argument 'port' must be legal, "
|
|
238
|
+
f"but got {port}."
|
|
239
|
+
)
|
|
240
|
+
if world_size <= 0:
|
|
241
|
+
raise ValueError(
|
|
242
|
+
"For 'TCPStore', the argument 'world_size' must be legal, "
|
|
243
|
+
f"but got {world_size}."
|
|
244
|
+
)
|
|
245
|
+
timeout_ms = int(timeout.total_seconds() * 1000)
|
|
246
|
+
self.instance = TCPStoreClient(host_name, port, is_master, timeout_ms, world_size, wait_for_workers)
|
|
247
|
+
self.host = host_name
|
|
248
|
+
self.port = port
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def add(self, key, amount):
|
|
252
|
+
"""
|
|
253
|
+
When the `add` function is called for the first time with a given key, it creates a counter in
|
|
254
|
+
the storage corresponding to that key, with the initial value set to `amount`. Subsequent calls
|
|
255
|
+
to `add` with the same key increment the counter by amount.
|
|
256
|
+
|
|
257
|
+
Args:
|
|
258
|
+
key (str): The key whose counter value will be incremented.
|
|
259
|
+
amount (int): The amount by which the counter will be incremented.
|
|
260
|
+
|
|
261
|
+
Returns:
|
|
262
|
+
int, value of counter with `key`.
|
|
263
|
+
|
|
264
|
+
Raises:
|
|
265
|
+
TypeError: If `key` is not string.
|
|
266
|
+
TypeError: If `amount` is not int.
|
|
267
|
+
RuntimeError: If the `add` and `set` pass the same `key` and the `value` passed by `set` cannot
|
|
268
|
+
be correctly converted to a numerical value, calling `add` will result in an error.
|
|
269
|
+
|
|
270
|
+
Supported Platforms:
|
|
271
|
+
``Ascend``
|
|
272
|
+
|
|
273
|
+
Examples:
|
|
274
|
+
.. note::
|
|
275
|
+
Before running the following examples, you need to configure the communication environment variables.
|
|
276
|
+
|
|
277
|
+
For Ascend devices, it is recommended to use the msrun startup method
|
|
278
|
+
without any third-party or configuration file dependencies.
|
|
279
|
+
Please see the `msrun start up
|
|
280
|
+
<https://www.mindspore.cn/tutorials/en/master/parallel/msrun_launcher.html>`_
|
|
281
|
+
for more details.
|
|
282
|
+
|
|
283
|
+
>>> from mindspore.mint.distributed import TCPStore
|
|
284
|
+
>>> store = TCPStore("127.0.0.1", 1234)
|
|
285
|
+
>>> store.add("first_key", 1)
|
|
286
|
+
"""
|
|
287
|
+
if not isinstance(key, str):
|
|
288
|
+
raise TypeError(
|
|
289
|
+
"For 'TCPStore.add', the argument 'key' must be type of string, "
|
|
290
|
+
"but got 'key' type : {}.".format(type(key))
|
|
291
|
+
)
|
|
292
|
+
if not isinstance(amount, int):
|
|
293
|
+
raise TypeError(
|
|
294
|
+
"For 'TCPStore.add', the argument 'amount' must be type of string or int, "
|
|
295
|
+
"but got 'amount' type : {}.".format(type(amount))
|
|
296
|
+
)
|
|
297
|
+
return self.instance.add(key, amount)
|
|
202
298
|
|
|
203
299
|
|
|
204
300
|
def set(self, key, value):
|
|
@@ -229,7 +325,7 @@ class TCPStore:
|
|
|
229
325
|
for more details.
|
|
230
326
|
|
|
231
327
|
>>> from mindspore.mint.distributed import TCPStore
|
|
232
|
-
>>> store = TCPStore()
|
|
328
|
+
>>> store = TCPStore("127.0.0.1", 1234)
|
|
233
329
|
>>> store.set("first_key", "first_value")
|
|
234
330
|
"""
|
|
235
331
|
if not isinstance(key, str):
|
|
@@ -247,8 +343,9 @@ class TCPStore:
|
|
|
247
343
|
|
|
248
344
|
def get(self, key):
|
|
249
345
|
"""
|
|
250
|
-
Retrieves the value associated with the given `key` in the store. If `key`
|
|
251
|
-
|
|
346
|
+
Retrieves the value associated with the given `key` in the store. If the `key` does not exist
|
|
347
|
+
in the storage, this function will wait for the `timeout` set by the class initialization and then
|
|
348
|
+
throw an exception.
|
|
252
349
|
|
|
253
350
|
Args:
|
|
254
351
|
key (str): The function will return the value associated with this key.
|
|
@@ -258,6 +355,7 @@ class TCPStore:
|
|
|
258
355
|
|
|
259
356
|
Raises:
|
|
260
357
|
TypeError: If `key` is not string.
|
|
358
|
+
RuntimeError: If `get` runs out of time.
|
|
261
359
|
|
|
262
360
|
Supported Platforms:
|
|
263
361
|
``Ascend``
|
|
@@ -273,7 +371,7 @@ class TCPStore:
|
|
|
273
371
|
for more details.
|
|
274
372
|
|
|
275
373
|
>>> from mindspore.mint.distributed import TCPStore
|
|
276
|
-
>>> store = TCPStore()
|
|
374
|
+
>>> store = TCPStore("127.0.0.1", 1234)
|
|
277
375
|
>>> store.set("first_key", "first_value")
|
|
278
376
|
>>> data = store.get("first_key")
|
|
279
377
|
>>> print(data)
|
|
@@ -301,7 +399,7 @@ class TCPStore:
|
|
|
301
399
|
TypeError: If `key` is not string.
|
|
302
400
|
|
|
303
401
|
Supported Platforms:
|
|
304
|
-
``
|
|
402
|
+
``Ascend``
|
|
305
403
|
|
|
306
404
|
Examples:
|
|
307
405
|
.. note::
|
|
@@ -314,7 +412,7 @@ class TCPStore:
|
|
|
314
412
|
for more details.
|
|
315
413
|
|
|
316
414
|
>>> from mindspore.mint.distributed import TCPStore
|
|
317
|
-
>>> store = TCPStore()
|
|
415
|
+
>>> store = TCPStore("127.0.0.1", 1234)
|
|
318
416
|
>>> store.set("first_key", "first_value")
|
|
319
417
|
>>> # This should return true
|
|
320
418
|
>>> store.delete_key("first_key")
|
|
@@ -389,6 +487,7 @@ def is_initialized():
|
|
|
389
487
|
return _is_initialized()
|
|
390
488
|
|
|
391
489
|
|
|
490
|
+
@args_type_check(init_method=str, timeout=timedelta, world_size=int, rank=int, store=TCPStore)
|
|
392
491
|
def init_process_group(backend="hccl",
|
|
393
492
|
init_method=None,
|
|
394
493
|
timeout=None,
|
|
@@ -406,26 +505,29 @@ def init_process_group(backend="hccl",
|
|
|
406
505
|
and the instantiation and execution of any operation and net.
|
|
407
506
|
|
|
408
507
|
Args:
|
|
409
|
-
backend (str, optional): The backend to ues.
|
|
410
|
-
init_method (str,
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
information. Provides parameters consistent with pytorch, but is not currently support,
|
|
419
|
-
setting is invalid.
|
|
508
|
+
backend (str, optional): The backend to ues. Default is ``"hccl"`` and now only support hccl.
|
|
509
|
+
init_method (str, optional): URL specifying how to init collective communication group. Default is ``None``.
|
|
510
|
+
timeout (timedelta, optional): Timeout for API executed. Default is ``None``. Currently, this parameter is
|
|
511
|
+
only supported for host-side cluster network configuration using `init_method` or `store`.
|
|
512
|
+
world_size (int, optional): Number of the processes participating in the job. Default is ``-1``.
|
|
513
|
+
rank (int, optional): Rank of the current process. Default is ``-1``.
|
|
514
|
+
store (Store, optional): An object that stores key/value data, facilitating the exchange of inter-process
|
|
515
|
+
communication addresses and connection information. Default is ``None``. Currently, only the
|
|
516
|
+
``TCPStore`` type is supported.
|
|
420
517
|
pg_options (ProcessGroupOptions, invalid): process group options specifying what additional options need to be
|
|
421
|
-
passed in during the construction of specific process group.
|
|
422
|
-
|
|
423
|
-
device_id (int, invalid): the device id to exeute.
|
|
424
|
-
|
|
518
|
+
passed in during the construction of specific process group. The provided parameter is a reserved
|
|
519
|
+
parameter, and the current setting does not take effect.
|
|
520
|
+
device_id (int, invalid): the device id to exeute. The provided parameter is a reserved parameter,
|
|
521
|
+
and the current setting does not take effect.
|
|
425
522
|
|
|
426
523
|
Raises:
|
|
427
524
|
ValueError: If `backend` is not hccl.
|
|
428
525
|
ValueError: If `world_size` is not equal to -1 or process group number.
|
|
526
|
+
ValueError: If both `init_method` and `store` are set.
|
|
527
|
+
ValueError: `world_size` is not correctly set as a positive integer value, when using the initialization
|
|
528
|
+
method `init_method` or `store`.
|
|
529
|
+
ValueError: `rank` is not correctly set as a non-negative integer, when using the initialization method
|
|
530
|
+
`init_method` or `store`.
|
|
429
531
|
RuntimeError: If device target is invalid, or backend is invalid, or distributed initialization fails,
|
|
430
532
|
or the environment variables RANK_ID/MINDSPORE_HCCL_CONFIG_PATH
|
|
431
533
|
have not been exported when backend is HCCL.
|
|
@@ -449,25 +551,34 @@ def init_process_group(backend="hccl",
|
|
|
449
551
|
>>> init_process_group()
|
|
450
552
|
>>> destroy_process_group()
|
|
451
553
|
"""
|
|
452
|
-
if init_method is not None:
|
|
453
|
-
logger.warning("init_method is ignored, setting is invalid")
|
|
454
|
-
if timeout is not None:
|
|
455
|
-
logger.warning("timeout is ignored, setting is invalid")
|
|
456
|
-
if store is not None:
|
|
457
|
-
logger.warning("store is ignored, setting is invalid")
|
|
458
554
|
if pg_options is not None:
|
|
459
555
|
logger.warning("pg_options is ignored, setting is invalid")
|
|
460
556
|
if device_id is not None:
|
|
461
557
|
logger.warning("device_id is ignored, setting is invalid")
|
|
462
|
-
if rank != -1:
|
|
463
|
-
logger.warning("rank is ignored, setting is invalid")
|
|
464
558
|
if backend != "hccl":
|
|
465
559
|
raise ValueError(
|
|
466
560
|
"Only support hccl now, please setting backend to hccl or using default value"
|
|
467
561
|
)
|
|
468
562
|
|
|
469
|
-
|
|
470
|
-
|
|
563
|
+
if init_method is not None and store is not None:
|
|
564
|
+
raise ValueError(
|
|
565
|
+
"Only one of init_method and store is supported."
|
|
566
|
+
)
|
|
567
|
+
if init_method is not None or store is not None:
|
|
568
|
+
if world_size <= 0:
|
|
569
|
+
raise ValueError(
|
|
570
|
+
"Specified world_size must be a positive integer."
|
|
571
|
+
)
|
|
572
|
+
if rank < 0:
|
|
573
|
+
raise ValueError(
|
|
574
|
+
"Specified rank must be a non-negative integer."
|
|
575
|
+
)
|
|
576
|
+
if timeout is None:
|
|
577
|
+
timeout = timedelta(seconds=300)
|
|
578
|
+
timeout_ms = int(timeout.total_seconds() * 1000)
|
|
579
|
+
_init_without_sched(backend, init_method, timeout_ms, world_size, rank, store)
|
|
580
|
+
else:
|
|
581
|
+
init(backend)
|
|
471
582
|
|
|
472
583
|
if world_size != -1 and world_size != get_group_size():
|
|
473
584
|
raise ValueError(
|
|
@@ -515,7 +626,10 @@ def destroy_process_group(group=None):
|
|
|
515
626
|
"""
|
|
516
627
|
|
|
517
628
|
if group == GlobalComm.WORLD_COMM_GROUP or group is None:
|
|
518
|
-
|
|
629
|
+
_pynative_executor.sync()
|
|
630
|
+
_finalize_collective()
|
|
631
|
+
_ExistingGroup.ITEMS.clear()
|
|
632
|
+
_ExistingGroup.GROUP_RANKS.clear()
|
|
519
633
|
elif not isinstance(group, str):
|
|
520
634
|
raise TypeError(
|
|
521
635
|
"For 'destroy_group', the argument 'group' must be type of string or None, "
|
|
@@ -673,6 +787,12 @@ def new_group(ranks=None,
|
|
|
673
787
|
hccl_config(dict)
|
|
674
788
|
}
|
|
675
789
|
|
|
790
|
+
`hccl_config` currently only supports "hccl_buffer_size" or "hccl_comm".
|
|
791
|
+
|
|
792
|
+
- hccl_buffer_size (uint32): specifies the size of the HCCL communication buffer.
|
|
793
|
+
- hccl_comm (int64): specifies an existing HcclComm pointer. If "hccl_comm" is set,
|
|
794
|
+
"hccl_buffer_size" will be ignored.
|
|
795
|
+
|
|
676
796
|
use_local_synchronization (bool, invalid): Currently it is a reserved parameter.
|
|
677
797
|
group_desc (str, invalid): Currently it is a reserved parameter.
|
|
678
798
|
|
|
@@ -1223,9 +1343,9 @@ def all_gather_into_tensor_uneven(output, input, output_split_sizes=None, group=
|
|
|
1223
1343
|
>>> ms.set_device(device_target="Ascend")
|
|
1224
1344
|
>>> init_process_group()
|
|
1225
1345
|
>>> if get_rank() == 0:
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1346
|
+
... input_tensor = Tensor(np.ones([3, 4]).astype(np.float32))
|
|
1347
|
+
... else:
|
|
1348
|
+
... input_tensor = Tensor(np.ones([2, 4]).astype(np.float32))
|
|
1229
1349
|
>>> out_tensor = Tensor(np.zeros([5, 4]).astype(np.float32))
|
|
1230
1350
|
>>> output_split_sizes = [3, 2]
|
|
1231
1351
|
>>> output = all_gather_into_tensor_uneven(out_tensor, input_tensor, output_split_sizes)
|
|
@@ -1357,7 +1477,7 @@ def reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=Reduc
|
|
|
1357
1477
|
|
|
1358
1478
|
Args:
|
|
1359
1479
|
output(Tensor): the output tensor has the same dtype as `input` with a shape of
|
|
1360
|
-
:math:`(
|
|
1480
|
+
:math:`(input\_split\_sizes[rank], *)`, where rank is the local rank id of the device.
|
|
1361
1481
|
input(Tensor): The input tensor to be reduced and scattered, Expected shape :math:`(N, *)`, where `*`
|
|
1362
1482
|
means any number of additional dimensions. N must equal the sum of `input_split_sizes` across ranks.
|
|
1363
1483
|
input_split_sizes (list[int], optional): List specifying how to split the first dimension of input tensor.
|
|
@@ -1401,9 +1521,9 @@ def reduce_scatter_tensor_uneven(output, input, input_split_sizes=None, op=Reduc
|
|
|
1401
1521
|
>>> init_process_group()
|
|
1402
1522
|
>>> input_tensor = Tensor(np.ones([5, 8]).astype(np.float32))
|
|
1403
1523
|
>>> if get_rank() == 0:
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1524
|
+
... output_tensor = Tensor(np.ones([2, 8]).astype(np.float32))
|
|
1525
|
+
... else:
|
|
1526
|
+
... output_tensor = Tensor(np.ones([3, 8]).astype(np.float32))
|
|
1407
1527
|
>>> input_split_sizes = [2, 3]
|
|
1408
1528
|
>>> output = reduce_scatter_tensor_uneven(output_tensor, input_tensor, input_split_sizes)
|
|
1409
1529
|
>>> print(output_tensor)
|
mindspore/mint/nn/__init__.py
CHANGED
|
@@ -61,6 +61,7 @@ from mindspore.nn.layer import ReLU
|
|
|
61
61
|
|
|
62
62
|
# 14
|
|
63
63
|
from mindspore.nn.layer.basic import DropoutExt as Dropout
|
|
64
|
+
from mindspore.nn.layer.basic import Dropout2dExt as Dropout2d
|
|
64
65
|
# 15
|
|
65
66
|
from mindspore.mint.nn.layer.conv import Conv1d, Conv2d, Conv3d, ConvTranspose2d
|
|
66
67
|
# 16
|
|
@@ -260,9 +261,6 @@ from mindspore.mint.nn.layer.activation import Threshold
|
|
|
260
261
|
# 258
|
|
261
262
|
from mindspore.ops.function.nn_func import mse_loss_ext
|
|
262
263
|
|
|
263
|
-
# 393
|
|
264
|
-
from mindspore.mint.nn.layer.basic import Dropout2d
|
|
265
|
-
|
|
266
264
|
# 406
|
|
267
265
|
from mindspore.mint.nn.layer.activation import ELU
|
|
268
266
|
|
|
@@ -325,9 +323,6 @@ class NLLLoss(Cell):
|
|
|
325
323
|
\sum_{n=1}^{N} l_{n}, & \text { if reduction }=\text { 'sum' }
|
|
326
324
|
\end{array}\right.
|
|
327
325
|
|
|
328
|
-
.. warning::
|
|
329
|
-
This is an experimental API that is subject to change or deletion.
|
|
330
|
-
|
|
331
326
|
Args:
|
|
332
327
|
weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
|
|
333
328
|
If not None, the shape is :math:`(C,)`, data type must be float16 or float32 or bfloat16(only supported by
|
|
@@ -696,9 +691,6 @@ class ReLU6(Cell):
|
|
|
696
691
|
r"""
|
|
697
692
|
Activation function ReLU6.
|
|
698
693
|
|
|
699
|
-
.. warning::
|
|
700
|
-
This is an experimental API that is subject to change or deletion.
|
|
701
|
-
|
|
702
694
|
Refer to :func:`mindspore.mint.nn.functional.relu6` for more details.
|
|
703
695
|
|
|
704
696
|
ReLU6 Activation Function Graph:
|
|
@@ -847,9 +839,6 @@ class SmoothL1Loss(Cell):
|
|
|
847
839
|
|
|
848
840
|
Refer to :func:`mindspore.mint.nn.functional.smooth_l1_loss` for more details.
|
|
849
841
|
|
|
850
|
-
.. warning::
|
|
851
|
-
This is an experimental API that is subject to change or deletion.
|
|
852
|
-
|
|
853
842
|
Supported Platforms:
|
|
854
843
|
``Ascend``
|
|
855
844
|
|
|
@@ -1190,7 +1179,7 @@ class PixelShuffle(Cell):
|
|
|
1190
1179
|
>>> input = mint.randn(1, 9, 4, 4)
|
|
1191
1180
|
>>> output = pixel_shuffle(input)
|
|
1192
1181
|
>>> print(output.shape)
|
|
1193
|
-
|
|
1182
|
+
(1, 1, 12, 12)
|
|
1194
1183
|
"""
|
|
1195
1184
|
|
|
1196
1185
|
def __init__(self, upscale_factor):
|
|
@@ -1448,9 +1437,6 @@ __all__ = [
|
|
|
1448
1437
|
|
|
1449
1438
|
# 388
|
|
1450
1439
|
'AdaptiveMaxPool2d',
|
|
1451
|
-
|
|
1452
|
-
# 393
|
|
1453
|
-
'Dropout2d',
|
|
1454
1440
|
# 406
|
|
1455
1441
|
'ELU',
|
|
1456
1442
|
# 407
|