mindspore 2.7.0__cp310-cp310-win_amd64.whl → 2.7.0rc1__cp310-cp310-win_amd64.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +2 -2
- mindspore/_extends/builtin_operations.py +3 -3
- mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +1 -1
- mindspore/_extends/parse/__init__.py +3 -3
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +1 -0
- mindspore/_extends/parse/parser.py +22 -28
- mindspore/_extends/parse/standard_method.py +1 -15
- mindspore/_extends/pijit/pijit_func_white_list.py +5 -2
- mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
- mindspore/amp.py +18 -0
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/common/__init__.py +12 -18
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +38 -102
- mindspore/common/_utils.py +1 -9
- mindspore/common/api.py +106 -155
- mindspore/common/{dynamic_shape/auto_dynamic_shape.py → auto_dynamic_shape.py} +23 -17
- mindspore/common/dtype.py +57 -98
- mindspore/common/dump.py +1 -1
- mindspore/common/file_system.py +9 -59
- mindspore/common/hook_handle.py +3 -22
- mindspore/common/np_dtype.py +3 -3
- mindspore/common/parameter.py +20 -4
- mindspore/common/recompute.py +4 -2
- mindspore/common/tensor.py +52 -38
- mindspore/communication/_hccl_management.py +297 -0
- mindspore/context.py +21 -15
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/transforms.py +1 -1
- mindspore/dataset/core/config.py +1 -35
- mindspore/dataset/engine/datasets.py +315 -330
- mindspore/dataset/engine/datasets_user_defined.py +22 -38
- mindspore/dataset/transforms/c_transforms.py +2 -2
- mindspore/dataset/transforms/transforms.py +3 -3
- mindspore/dataset/vision/__init__.py +1 -1
- mindspore/dataset/vision/py_transforms.py +8 -8
- mindspore/dataset/vision/transforms.py +5 -17
- mindspore/dataset/vision/utils.py +21 -632
- mindspore/device_context/ascend/op_tuning.py +1 -35
- mindspore/dnnl.dll +0 -0
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -3
- mindspore/include/api/cell.h +4 -28
- mindspore/include/api/cfg.h +7 -24
- mindspore/include/api/context.h +0 -1
- mindspore/include/api/delegate.h +2 -0
- mindspore/include/api/dual_abi_helper.h +19 -100
- mindspore/include/api/graph.h +1 -14
- mindspore/include/api/kernel.h +3 -16
- mindspore/include/api/kernel_api.h +1 -9
- mindspore/include/api/metrics/accuracy.h +0 -9
- mindspore/include/api/model.h +1 -5
- mindspore/include/api/model_group.h +0 -4
- mindspore/include/api/model_parallel_runner.h +0 -2
- mindspore/include/api/status.h +10 -48
- mindspore/include/api/types.h +1 -6
- mindspore/include/dataset/constants.h +0 -9
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/tools/cifar10.py +2 -3
- mindspore/mindrecord/tools/cifar10_to_mr.py +5 -5
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mindspore_ops_host.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mint/distributed/__init__.py +0 -4
- mindspore/mint/distributed/distributed.py +14 -217
- mindspore/mint/nn/layer/_functions.py +2 -1
- mindspore/mint/nn/layer/conv.py +6 -6
- mindspore/mint/nn/layer/normalization.py +3 -3
- mindspore/nn/cell.py +174 -216
- mindspore/nn/layer/activation.py +2 -4
- mindspore/nn/layer/basic.py +13 -7
- mindspore/nn/layer/image.py +1 -1
- mindspore/nn/optim/adam.py +3 -1
- mindspore/nn/optim/lamb.py +3 -1
- mindspore/nn/optim/tft_wrapper.py +3 -2
- mindspore/nn/probability/distribution/_utils/utils.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +5 -39
- mindspore/nn/wrap/grad_reducer.py +15 -0
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/utils_const.py +1 -1
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/_grad_experimental/grad_inner_ops.py +9 -0
- mindspore/ops/_op_impl/cpu/__init__.py +0 -1
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +2 -12
- mindspore/ops/auto_generate/gen_extend_func.py +4 -4
- mindspore/ops/auto_generate/gen_ops_def.py +16 -290
- mindspore/ops/auto_generate/gen_ops_prim.py +76 -563
- mindspore/ops/composite/base.py +1 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +1 -1
- mindspore/ops/function/__init__.py +0 -1
- mindspore/ops/function/array_func.py +6 -10
- mindspore/ops/function/debug_func.py +2 -4
- mindspore/ops/function/grad/grad_func.py +12 -4
- mindspore/ops/function/math_func.py +32 -44
- mindspore/ops/function/nn_func.py +20 -18
- mindspore/ops/functional.py +1 -2
- mindspore/ops/functional_overload.py +12 -23
- mindspore/ops/operations/_inner_ops.py +12 -11
- mindspore/ops/operations/array_ops.py +50 -4
- mindspore/ops/operations/comm_ops.py +15 -1
- mindspore/ops/operations/custom_ops.py +4 -10
- mindspore/ops/operations/debug_ops.py +6 -6
- mindspore/ops/operations/manually_defined/ops_def.py +12 -12
- mindspore/ops/operations/math_ops.py +5 -5
- mindspore/ops/operations/nn_ops.py +1 -1
- mindspore/ops/primitive.py +10 -3
- mindspore/ops/tensor_method.py +7 -16
- mindspore/ops_generate/pyboost/gen_pyboost_func.py +16 -0
- mindspore/parallel/_auto_parallel_context.py +15 -5
- mindspore/parallel/_parallel_serialization.py +2 -3
- mindspore/parallel/_ps_context.py +2 -2
- mindspore/parallel/_transformer/transformer.py +4 -4
- mindspore/parallel/_utils.py +11 -5
- mindspore/parallel/auto_parallel.py +9 -23
- mindspore/parallel/checkpoint_transform.py +0 -2
- mindspore/parallel/cluster/process_entity/_api.py +1 -4
- mindspore/parallel/cluster/run.py +3 -5
- mindspore/parallel/function/reshard_func.py +5 -6
- mindspore/parallel/nn/parallel_cell_wrapper.py +3 -40
- mindspore/parallel/nn/parallel_grad_reducer.py +8 -0
- mindspore/parallel/shard.py +21 -7
- mindspore/parallel/transform_safetensors.py +4 -10
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +9 -10
- mindspore/profiler/analysis/viewer/ascend_op_memory_viewer.py +1 -1
- mindspore/profiler/common/msprof_cmd_tool.py +2 -2
- mindspore/profiler/common/path_manager.py +0 -9
- mindspore/profiler/common/profiler_context.py +2 -25
- mindspore/profiler/common/profiler_meta_data.py +0 -1
- mindspore/profiler/common/profiler_op_analyse.py +6 -10
- mindspore/{ops/_op_impl/cpu/joinedstr_op.py → profiler/common/validator/__init__.py} +1 -15
- mindspore/profiler/common/validator/validate_path.py +84 -0
- mindspore/profiler/dynamic_profiler.py +46 -91
- mindspore/profiler/envprofiler.py +5 -30
- mindspore/profiler/experimental_config.py +1 -16
- mindspore/profiler/platform/cpu_profiler.py +4 -10
- mindspore/profiler/platform/npu_profiler.py +1 -1
- mindspore/profiler/profiler.py +145 -193
- mindspore/profiler/profiler_action_controller.py +1 -1
- mindspore/profiler/profiler_interface.py +2 -2
- mindspore/rewrite/symbol_tree/symbol_tree.py +1 -1
- mindspore/runtime/__init__.py +4 -6
- mindspore/runtime/executor.py +0 -27
- mindspore/runtime/memory.py +0 -1
- mindspore/runtime/thread_bind_core.py +1 -1
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/train/_utils.py +3 -3
- mindspore/train/amp.py +3 -0
- mindspore/train/callback/_callback.py +1 -2
- mindspore/train/callback/_checkpoint.py +8 -1
- mindspore/train/callback/_flops_collector.py +6 -10
- mindspore/train/callback/_train_fault_tolerance.py +7 -3
- mindspore/train/data_sink.py +4 -4
- mindspore/train/dataset_helper.py +5 -5
- mindspore/train/model.py +20 -4
- mindspore/train/serialization.py +15 -35
- mindspore/train/train_thor/model_thor.py +2 -2
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/hooks.py +81 -0
- mindspore/utils/utils.py +8 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/METADATA +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/RECORD +193 -192
- mindspore/_extends/parallel_compile/akg_compiler/custom.py +0 -1109
- mindspore/common/dynamic_shape/__init__.py +0 -0
- mindspore/common/dynamic_shape/enable_dynamic.py +0 -197
- /mindspore/common/{dynamic_shape/_auto_dynamic.py → _auto_dynamic.py} +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.0rc1.dist-info}/top_level.txt +0 -0
mindspore/common/tensor.py
CHANGED

@@ -419,6 +419,9 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
     def __rmod__(self, other):
         return _rmod_instance(other, self)
 
+    def __imod__(self, other):
+        return self.__mod__(other)
+
     def __rpow__(self, other):
         return tensor_operator_registry.get('__rpow__')(self, other)
 
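The net effect of this hunk: in-place modulo on a Tensor now resolves through `Tensor.__imod__`, which simply delegates to `__mod__`, so `x %= y` rebinds `x` to a new Tensor rather than failing. A minimal sketch of the resulting behavior, assuming a build where this hunk is present:

    import mindspore as ms

    x = ms.Tensor([5, 7, 9], ms.int32)
    x %= 4       # dispatches to __imod__, i.e. self.__mod__(other)
    print(x)     # [1 3 1]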
@@ -1147,6 +1150,7 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
             which may be modified by returning a new output gradient.
             - The `hook` should have the following signature:
               hook(grad) -> New output gradient, but can not return None or not set return value.
+            - Higher-order differentiation does not support tensor `register_hook`.
             - The following constraints must be met under graph mode:
 
                 - The `hook` must satisfy the syntax constraints of the graph mode.
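For orientation, the constraint added above concerns hooks registered on tensors; a typical first-order PyNative use, sketched after the documented pattern (names and values illustrative), is unaffected:

    import mindspore as ms
    from mindspore import Tensor

    ms.set_context(mode=ms.PYNATIVE_MODE)

    def double_grad(grad):
        return grad * 2            # must return a new gradient, never None

    def net(x, y):
        z = x * y
        z.register_hook(double_grad)   # first-order differentiation only, per the added note
        return z * y

    print(ms.grad(net, grad_position=(0, 1))(Tensor(1.0, ms.float32), Tensor(2.0, ms.float32)))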
@@ -1851,10 +1855,6 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
 
             self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
 
-        .. warning::
-            When deterministic computation is enabled, `index` can not be a non-contiguous Tensor; otherwise,
-            deterministic results can not be guaranteed.
-
         Args:
             dim (int): Which dim to scatter. Accepted range is [-r, r) where r = rank(`self`).
             index (Tensor): The index of `self` to do scatter operation whose data type must
@@ -2101,9 +2101,9 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
         try:
             dtype_ = mstype.int8 if is_qint4x2 else self.dtype
             if isinstance(self.init, ZeroInitializer):
-                data = np.zeros(data_shape, dtype=mstype.
+                data = np.zeros(data_shape, dtype=mstype.dtype_to_nptype(dtype_))
             else:
-                data = np.ndarray(data_shape, dtype=mstype.
+                data = np.ndarray(data_shape, dtype=mstype.dtype_to_nptype(dtype_))
         except ValueError as e:
             msg = "Error shape={}".format(shape)
             logger.critical(msg)
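Both replacement lines route through `mstype.dtype_to_nptype`, which maps a MindSpore dtype to its NumPy counterpart. A quick illustration of that mapping (a sketch, not taken from the diff):

    import numpy as np
    import mindspore as ms

    np_dt = ms.dtype_to_nptype(ms.float16)        # maps ms dtype -> numpy dtype
    print(np.zeros((2, 3), dtype=np_dt).dtype)    # float16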
@@ -2586,34 +2586,6 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('bernoulli')(self, generator=generator)
 
-    def bernoulli_(self, p=0.5, *, generator=None):
-        r"""
-        Fills each location of self with an independent sample from Bernoulli(p).
-
-        Args:
-            p (Union[number.Number, Tensor], optional): `p` should either be a scalar or tensor containing
-                probabilities to be used for drawing the binary random number, between ``0`` and ``1`` .
-                If it is a tensor, `p` must be floating point. Default: ``0.5`` .
-
-        Keyword Args:
-            generator (:class:`mindspore.Generator`, optional): a pseudorandom number generator.
-                Default: ``None`` , uses the default pseudorandom number generator.
-
-        Returns:
-            The input tensor.
-
-        Supported Platforms:
-            ``Ascend``
-
-        Examples:
-            >>> from mindspore import Tensor
-            >>> x = Tensor([[2, 3, 4], [1, 2, 3]])
-            >>> p = 0.1
-            >>> print(x.bernoulli_(p).shape)
-            (2, 3)
-        """
-        return tensor_operator_registry.get('bernoulli_')(self, p, generator=generator)
-
     def random_(self, from_=0, to=None, *, generator=None):
         r"""
         Fill the tensor with numbers sampled from a discrete uniform distribution over an
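With the in-place `bernoulli_` absent on the 2.7.0rc1 side, the out-of-place `Tensor.bernoulli` kept by the context lines remains the entry point. A hedged sketch, assuming PyTorch-style semantics where each element of `self` is treated as a probability (the semantics are inferred, not shown in this diff):

    import mindspore as ms

    x = ms.Tensor([[0.2, 0.9], [0.5, 0.5]], ms.float32)
    out = x.bernoulli()    # out-of-place; returns a new tensor of 0.0/1.0 draws
    print(out.shape)       # (2, 2)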
@@ -2967,8 +2939,7 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
         taken into account. As long as the real part is non-zero, it returns True; otherwise, it returns False.
 
         Args:
-            dtype (dtype.Number
-                Only Support type bool in PyNative mode.
+            dtype (dtype.Number): The valid data type of the output tensor. Only constant value is allowed.
 
         Returns:
             Tensor, converted to the specified `dtype`.
@@ -3628,6 +3599,46 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
         """
         return tensor_operator_registry.get('ormqr')(self, input2, input3, left, transpose)
 
+    def masked_scatter(self, mask, x):
+        r"""
+        Updates the value in the "self Tensor" with the `tensor` value according to the mask, and returns a Tensor.
+        The shape of `mask` and the "self Tensor" must be the same or `mask` is broadcastable.
+
+        .. warning::
+            This is an experimental API that is subject to change or deletion.
+
+        Args:
+            mask (Tensor[bool]): A bool tensor with a shape broadcastable to the "self Tensor".
+            x (Tensor): A tensor with the same data type as the "self Tensor". The number
+                of elements must be greater than or equal to the number of True's in `mask`.
+
+        Returns:
+            Tensor, with the same type and shape as the "self Tensor".
+
+        Raises:
+            TypeError: If `mask` or `x` is not a Tensor.
+            TypeError: If data type of the "self Tensor" is not be supported.
+            TypeError: If dtype of `mask` is not bool.
+            TypeError: If the dim of the "self Tensor" less than the dim of `mask`.
+            ValueError: If `mask` can not be broadcastable to the "self Tensor".
+            ValueError: If the number of elements in `x` is less than the number required for the updates.
+
+        Supported Platforms:
+            ``Ascend`` ``CPU``
+
+        Examples:
+            >>> import numpy as np
+            >>> import mindspore
+            >>> from mindspore import Tensor
+            >>> x = Tensor(np.array([1., 2., 3., 4.]), mindspore.float32)
+            >>> mask = Tensor(np.array([True, True, False, True]), mindspore.bool_)
+            >>> tensor = Tensor(np.array([5., 6., 7.]), mindspore.float32)
+            >>> output = x.masked_scatter(mask, tensor)
+            >>> print(output)
+            [5. 6. 3. 7.]
+        """
+        return tensor_operator_registry.get('masked_scatter')()(self, mask, x)
+
     def index_put(self, indices, values, accumulate=False):
         r"""
         Based on the indices in `indices`, replace the corresponding elements in Tensor `self`
@@ -3768,6 +3779,9 @@ class Tensor(TensorPy_, metaclass=_TensorMeta):
             raise ValueError(f"The type of 'blocking' must be bool, but got {blocking}")
         if to not in ("Ascend", "GPU", "CPU"):
             raise ValueError(f"The value of 'to' must be one of ['Ascend', 'GPU', 'CPU'], but got {to}")
+        mode = context.get_context("mode")
+        if mode != context.PYNATIVE_MODE:
+            raise ValueError(f"The method of 'move_to' only supported in pynative mode, but got: {mode}.")
         return TensorPy_.move_to(self, to, blocking)
 
     def _offload(self):
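The three added lines make `move_to` fail fast outside PyNative mode instead of misbehaving later. Conforming usage therefore looks like this sketch:

    import mindspore as ms

    ms.set_context(mode=ms.PYNATIVE_MODE)   # required by the new check
    x = ms.Tensor([1.0, 2.0], ms.float32)
    y = x.move_to("CPU", True)              # graph mode would now raise ValueError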
@@ -3937,9 +3951,9 @@ def _check_astype_and_convert(dtype):
         if dtype.lower() not in all_types:
             raise TypeError(f"For Tensor.astype, the string input type must be one of {all_types}, "
                             f"but got '{dtype}'.")
-        dtype = mstype.
+        dtype = mstype.pytype_to_dtype(np.dtype(dtype.lower()))
     elif isinstance(dtype, type):
-        dtype = mstype.
+        dtype = mstype.pytype_to_dtype(dtype)
     elif dtype not in mstype.number_type + (mstype.bool_,):
         raise TypeError(
             f"For Tensor.astype, the input type must be one of {list(mstype.number_type + (mstype.bool_,) + np_types)},"
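`pytype_to_dtype` converts a Python or NumPy type object into a MindSpore dtype, so string and type inputs to `astype` now funnel through a single helper. An illustration of the two paths (a sketch):

    import mindspore as ms

    x = ms.Tensor([1.1, 2.2], ms.float32)
    print(x.astype("int32").dtype)   # Int32: the string is lowered via np.dtype first
    print(x.astype(bool).dtype)      # Bool: the Python type goes through pytype_to_dtype directly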
mindspore/communication/_hccl_management.py
ADDED

@@ -0,0 +1,297 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""HCCL management API"""
+from __future__ import absolute_import
+from __future__ import division
+
+import ctypes
+import os
+
+from mindspore import context
+from mindspore._c_expression import get_hccl_rank_id, get_hccl_rank_size
+
+MAX_GROUP_NAME_LEN = 127
+MAX_RANK_NUM = 4096
+HCCL_LIB = 'libhccl_plugin.so'
+HCCL_LIB_CTYPES = ""
+
+
+def check_group(group):
+    """
+    A function that check if a collection communication group is legal.
+
+    Returns:
+        None
+    """
+    if isinstance(group, (str)):
+        group_len = len(group)
+        if group_len > MAX_GROUP_NAME_LEN or group_len == 0:
+            raise ValueError("The length of communication group name must be in range [1, 127), "
+                             "but got the value : {} ".format(group_len))
+    else:
+        raise TypeError("The type of communication group name must be type of string, "
+                        "but got 'group' type : {}.".format(type(group)))
+
+
+def check_rank_num(rank_num):
+    """
+    A function that check if a collection communication rank number is legal.If not raise error.
+
+    Returns:
+        None
+    """
+    if isinstance(rank_num, (int)):
+        if rank_num > MAX_RANK_NUM or rank_num <= 0:
+            raise ValueError("For 'create_group', the size of argument 'rand_ids' should be greater than 0 and"
+                             "less than {}, but got the size of 'rank_ids' : {}.".format(MAX_RANK_NUM, rank_num))
+    else:
+        raise TypeError("The argument 'rank_num' must be type of int, "
+                        "but got 'rank_num' type : {}.".format(type(rank_num)))
+
+
+def check_rank_id(rank_id):
+    """
+    A function that check if a collection communication rank id is legal.If not raise error.
+
+    Returns:
+        None
+    """
+    if isinstance(rank_id, (int)):
+        if rank_id >= MAX_RANK_NUM or rank_id < 0:
+            raise ValueError("The rand id in the communication group must be greater or equal 0 and "
+                             "less than {}, but got type value : {}.".format(MAX_RANK_NUM, rank_id))
+    else:
+        raise TypeError("The rand id in the communication group must be must be type of int, "
+                        "but got type value : {}.".format(type(rank_id)))
+
+
+def load_lib():
+    """load hccl lib"""
+    try:
+        base_dir = os.path.dirname(os.path.realpath(__file__))
+        lib_path = os.path.join(base_dir, "../lib/plugin/ascend", HCCL_LIB)
+        hccl_lib = ctypes.CDLL(lib_path)
+    except Exception:
+        raise RuntimeError('Get hccl lib error.')
+
+    global HCCL_LIB_CTYPES
+    HCCL_LIB_CTYPES = hccl_lib
+
+
+def c_str(string):
+    """Convert a python string to C string."""
+    if not isinstance(string, str):
+        string = string.decode('ascii')
+    return ctypes.c_char_p(string.encode('utf-8'))
+
+
+def c_array(ctype, values):
+    """Create ctypes array from a python array."""
+    return (ctype * len(values))(*values)
+
+
+def create_group(group, rank_num, rank_ids):
+    """
+    Create group.
+
+    A function that creates a collection communication group which includes 'rank_num'
+    device and 'rank_ids' is the list of these ranks of devices.
+
+    Note:
+        The world group can not be created.
+
+    Returns:
+        None
+    """
+    check_group(group)
+    check_rank_num(rank_num)
+    if isinstance(rank_ids, (list)):
+        if rank_num != len(rank_ids):
+            raise ValueError("The argument 'rank_num' number should be equal to the length "
+                             "of rank_ids, but got 'rank_num' value : {} and 'rank_ids' value : {}."
+                             .format(rank_num, rank_ids))
+        for rank_id in rank_ids:
+            if not isinstance(rank_id, (int)) or rank_id < 0:
+                raise ValueError("The elements of argument 'rank_ids' must be "
+                                 "unsigned integer, but got the type : {}".format(type(rank_id)))
+        c_array_rank_ids = c_array(ctypes.c_uint, rank_ids)
+        c_rank_num = ctypes.c_uint(rank_num)
+        c_group = c_str(group)
+        ret = HCCL_LIB_CTYPES.HcomCreateGroup(c_group, c_rank_num, c_array_rank_ids)
+        if ret != 0:
+            raise RuntimeError('Create group error, the error code is {}.'.format(ret))
+    else:
+        raise TypeError("For 'create_group', the argument 'rank_ids' must be type of list, "
+                        "but got 'rank_ids' type : {}.".format(type(rank_ids)))
+
+
+def destroy_group(group):
+    """
+    A function that destroy the group which created by user.
+
+    Note:
+        The world group can not be destroy.
+
+    Returns:
+        None
+    """
+    check_group(group)
+    c_group = c_str(group)
+    ret = HCCL_LIB_CTYPES.HcomDestroyGroup(c_group)
+    if ret != 0:
+        raise RuntimeError('Destroy group error.')
+
+
+def get_rank_size(group="hccl_world_group"):
+    """
+    A function that returns the number of ranks within the given collection communication group.
+
+    Note:
+        The default group is hccl_world_group.
+
+    Returns:
+        An integer scalar with the num of ranks.
+    """
+
+    if context.get_context("mode") == context.PYNATIVE_MODE:
+        return get_hccl_rank_size()
+
+    check_group(group)
+    c_group = c_str(group)
+    c_rank_size = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetRankSize(c_group, ctypes.byref(c_rank_size))
+    if ret != 0:
+        raise RuntimeError('Get rank size error.')
+
+    return c_rank_size.value
+
+
+def get_rank_id(group="hccl_world_group"):
+    """
+    A function that returns the rank id of the calling process, within the given collection communication group.
+
+    Returns:
+        An integer scalar with the rank id of the calling process.
+    """
+
+    if context.get_context("mode") == context.PYNATIVE_MODE:
+        return get_hccl_rank_id()
+
+    check_group(group)
+    c_group = c_str(group)
+    c_rank_id = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetRankId(c_group, ctypes.byref(c_rank_id))
+    if ret != 0:
+        raise RuntimeError('Get rank id error.')
+
+    return c_rank_id.value
+
+
+
+def get_local_rank_size(group="hccl_world_group"):
+    """
+    A function that returns the number of local ranks within the given collection communication group.
+
+    Note:
+        The default group is hccl_world_group.
+
+    Returns:
+        An integer scalar with the num of local ranks.
+    """
+    if context.get_context("mode") is context.PYNATIVE_MODE:
+        raise RuntimeError("The function 'get_local_rank_size' is not supported in PYNATIVE_MODE, "
+                           "'get_local_rank_size' only support GRAPH_MODE")
+    check_group(group)
+    c_group = c_str(group)
+    c_local_rank_size = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetLocalRankSize(c_group, ctypes.byref(c_local_rank_size))
+    if ret != 0:
+        raise RuntimeError('Get local rank size error.')
+
+    return c_local_rank_size.value
+
+
+def get_local_rank_id(group="hccl_world_group"):
+    """
+    Get local rank id.
+
+    A function that returns the local rank id of the calling process, within the given collection communication group.
+
+    Returns:
+        An integer scalar with the local rank id of the calling process.
+    """
+
+    if context.get_context("mode") is context.PYNATIVE_MODE:
+        raise RuntimeError("The function 'get_local_rank_id' is not supported in PYNATIVE_MODE, "
+                           "'get_local_rank_id' only support GRAPH_MODE")
+    check_group(group)
+    c_group = c_str(group)
+    c_local_rank_id = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetLocalRankId(c_group, ctypes.byref(c_local_rank_id))
+    if ret != 0:
+        raise RuntimeError('Get local rank id error.')
+
+    return c_local_rank_id.value
+
+
+def get_world_rank_from_group_rank(group, group_rank_id):
+    """
+    Get world rank from group rank.
+
+    A function that returns the rank id in the world group corresponding to the
+    rank which id is 'group_rank_id' in the user group.
+
+    Returns:
+        An integer scalar with the rank id in the world group.
+    """
+    if context.get_context("mode") is context.PYNATIVE_MODE:
+        raise RuntimeError("The function 'get_world_rank_from_group_rank' is not supported in PYNATIVE_MODE, "
+                           "'get_world_rank_from_group_rank' only support GRAPH_MODE")
+    check_group(group)
+    check_rank_id(group_rank_id)
+    c_group = c_str(group)
+    c_group_rank_id = ctypes.c_uint(group_rank_id)
+    c_world_rank_id = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetWorldRankFromGroupRank(c_group, c_group_rank_id, ctypes.byref(c_world_rank_id))
+    if ret != 0:
+        raise RuntimeError('Get world rank from group rank error.')
+
+    return c_world_rank_id.value
+
+
+def get_group_rank_from_world_rank(world_rank_id, group):
+    """
+    Get group rank from world rank.
+
+    A function that returns the rank id in the user group corresponding to the
+    rank which id is 'world_rank_id' in the world group.
+
+    Returns:
+        An integer scalar with the rank id in the user group.
+    """
+    if context.get_context("mode") is context.PYNATIVE_MODE:
+        raise RuntimeError("The function 'get_group_rank_from_world_rank' is not supported in PYNATIVE_MODE, "
+                           "'get_group_rank_from_world_rank' only support GRAPH_MODE")
+    check_group(group)
+    check_rank_id(world_rank_id)
+    c_group = c_str(group)
+    c_world_rank_id = ctypes.c_uint(world_rank_id)
+    c_group_rank_id = ctypes.c_uint()
+    ret = HCCL_LIB_CTYPES.HcomGetGroupRankFromWorldRank(c_world_rank_id, c_group, ctypes.byref(c_group_rank_id))
+    if ret != 0:
+        raise RuntimeError('Get group rank from world rank error.')
+
+    return c_group_rank_id.value
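Taken together, this re-added module is a thin ctypes shim over the `Hcom*` entry points in libhccl_plugin.so, with PyNative fast paths only for rank id and rank size. A hedged usage sketch, assuming an Ascend machine with the HCCL plugin present:

    import mindspore as ms
    from mindspore.communication import _hccl_management as hccl

    ms.set_context(mode=ms.GRAPH_MODE)
    hccl.load_lib()                           # binds libhccl_plugin.so into HCCL_LIB_CTYPES
    hccl.create_group("group_0", 2, [0, 1])   # the world group itself cannot be created
    print(hccl.get_rank_size("group_0"))      # -> 2
    hccl.destroy_group("group_0")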
mindspore/context.py
CHANGED

@@ -204,6 +204,13 @@ class _Context:
         if mode == PYNATIVE_MODE:
             if self.enable_debug_runtime:
                 self.set_backend_policy("vm")
+            parallel_mode = _get_auto_parallel_context("parallel_mode")
+            if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE, ParallelMode.AUTO_PARALLEL):
+                raise ValueError(f"Got {parallel_mode}, when the user enabled SEMI_AUTO_PARALELL, "
+                                 f"pynative mode does not support, you should set either "
+                                 f"context.set_auto_parallel_context(parallel_mode='data_parallel'), "
+                                 f"context.set_auto_parallel_context(parallel_mode='stand_alone') "
+                                 f"or context.set_auto_parallel_context(parallel_mode='auto_parallel').")
             self._context_switches.push(True, None)
         elif mode == GRAPH_MODE:
             if self.enable_debug_runtime:
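The added guard rejects semi-auto parallel as soon as the mode switches to PyNative, instead of failing later. A configuration the check accepts:

    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="data_parallel")  # or "stand_alone" / "auto_parallel"
    ms.set_context(mode=ms.PYNATIVE_MODE)   # "semi_auto_parallel" here would raise the new ValueError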
@@ -598,12 +605,12 @@ class _Context:
     def set_mempool_block_size(self, mempool_block_size):
         """Set the block size of memory pool."""
         global_jit_config = get_jit_config()
-
+        is_force_kbk = False
         if global_jit_config:
-
-            if
-            logger.warning("
-                           "you can use pynative mode or set jit_level=O0/O1.")
+            is_force_kbk = global_jit_config.get('jit_level') == "O0" or global_jit_config.get('jit_level') == "O1"
+        if _get_mode() == GRAPH_MODE and not is_force_kbk:
+            logger.warning("Graph mode doesn't support to set parameter 'mempool_block_size' of context currently, "
+                           "you can use context.set_context to set pynative mode or set jit_level=O0/O1.")
             return
         if not Validator.check_str_by_regular(mempool_block_size, _RE_PATTERN):
             raise ValueError("For 'context.set_context', the argument 'mempool_block_size' should be in "
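After the rewrite, the warning fires only in graph mode without kernel-by-kernel execution; `mempool_block_size` keeps working in PyNative mode or under jit_level O0/O1. For instance:

    import mindspore as ms

    ms.set_context(mode=ms.PYNATIVE_MODE)
    ms.set_context(mempool_block_size="4GB")   # accepted; plain graph mode only logs a warning and returns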
@@ -836,8 +843,7 @@ class _Context:
     @staticmethod
     def _check_speedup_config_str_value(key, value):
         """check speedup config str value"""
-        if key in ["pp_1f1b_overlap", "recompute_comm_overlap", "recomputation_communication_overlap"
-                   "matmul_grad_comm_overlap", "grad_matmul_communication_overlap"]:
+        if key in ["pp_1f1b_overlap", "recompute_comm_overlap", "recomputation_communication_overlap"]:
             if isinstance(value, str):
                 values = value.split(",")
                 for v in values:
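Worth noting for reviewers: the removed pair of lines had no comma between the third and fourth string literals, so Python silently concatenated them and "matmul_grad_comm_overlap" never matched as a standalone key. The pitfall in isolation:

    >>> ["recompute_comm_overlap", "recomputation_communication_overlap"
    ...  "matmul_grad_comm_overlap"]
    ['recompute_comm_overlap', 'recomputation_communication_overlapmatmul_grad_comm_overlap']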
@@ -865,8 +871,8 @@ class _Context:
         try:
             valid_option = {"recompute_comm_overlap": (ms_ctx_param.recompute_comm_overlap, str),
                             "recomputation_communication_overlap": (ms_ctx_param.recompute_comm_overlap, str),
-                            "matmul_grad_comm_overlap": (ms_ctx_param.matmul_grad_comm_overlap,
-                            "grad_matmul_communication_overlap": (ms_ctx_param.matmul_grad_comm_overlap,
+                            "matmul_grad_comm_overlap": (ms_ctx_param.matmul_grad_comm_overlap, bool),
+                            "grad_matmul_communication_overlap": (ms_ctx_param.matmul_grad_comm_overlap, bool),
                             "enable_task_opt": (ms_ctx_param.enable_task_opt, bool),
                             "enable_communication_fusion": (ms_ctx_param.enable_task_opt, bool),
                             "enable_grad_comm_opt": (ms_ctx_param.enable_grad_comm_opt, bool),
@@ -1081,8 +1087,8 @@ def set_auto_parallel_context(**kwargs):
 
             - pipeline_interleave(bool): Indicates whether to enable the interleaved execution mode.
             - pipeline_scheduler(str): Indicates the scheduling mode for pipeline parallelism. Only support
-              ``gpipe/1f1b/seqpipe/seqvpp/seqsmartvpp
-
+              ``gpipe/1f1b/seqpipe/seqvpp/seqsmartvpp``. When applying seqsmartvpp, the pipeline parallel
+              must be an even number.
         parallel_optimizer_config (dict): A dict contains the keys and values for setting the parallel optimizer
             configure. The configure provides more detailed behavior control about parallel training
             when parallel optimizer is enabled. The configure will be effective when we use
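As a usage sketch for the documented scheduler knob (the `pipeline_config` argument layout is assumed from this docstring's conventions; illustrative only):

    import mindspore as ms

    ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                 pipeline_stages=4,
                                 pipeline_config={"pipeline_interleave": True,
                                                  "pipeline_scheduler": "1f1b"})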
@@ -1781,8 +1787,8 @@ def set_ps_context(**kwargs):
         config_file_path (str): Configuration file path used by recovery, parameter server training mode only
             supports Server disaster recovery currently. Default: ``''`` .
         enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: ``False``.
-
-
+            Turning it off by default may be a security risk,
+            and users need to ensure the security of the network environment.
         client_password (str): Password to decrypt the secret key stored in the client certificate. Default: ``''`` .
         server_password (str): Password to decrypt the secret key stored in the server certificate. Default: ``''`` .
 
@@ -1809,8 +1815,8 @@ def get_ps_context(attr_key):
             parameter server training mode only
             supports Server disaster recovery currently. Default: ``''`` .
         - enable_ssl (bool, optional): Set PS SSL mode enabled or disabled. Default: ``False`` .
-
-
+          Turning it off by default may be a security risk,
+          and users need to ensure the security of the network environment.
 
     Returns:
         Returns attribute value according to the key.
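The setter/getter pair these two hunks document can be exercised like this (a sketch):

    import mindspore as ms

    ms.set_ps_context(enable_ssl=False)      # default; the new text warns this can be a security risk
    print(ms.get_ps_context("enable_ssl"))   # -> False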
mindspore/dataset/__init__.py
CHANGED

@@ -13,7 +13,7 @@
 # limitations under the License.
 """
 At the heart of MindSpore data loading utility is the `mindspore.dataset` module.
-It is a `dataset engine <https://www.mindspore.cn/docs/en/master/
+It is a `dataset engine <https://www.mindspore.cn/docs/en/master/design/data_engine.html>`_ based on pipline design.
 
 This module provides the following data loading methods to help users load datasets into MindSpore.
 
mindspore/dataset/audio/transforms.py
CHANGED

@@ -2793,7 +2793,7 @@ class PhaseVocoder(AudioTensorOperation):
     Raises:
         TypeError: If `rate` is not of type float.
         ValueError: If `rate` is not a positive number.
-        TypeError: If `phase_advance` is not of type
+        TypeError: If `phase_advance` is not of type :class:`numpy.ndarray` .
         RuntimeError: If input tensor is not in shape of <..., freq, num_frame, complex=2>.
 
     Supported Platforms:
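A sketch of the operation whose Raises entry was completed above, following the shape convention <freq, num_frame, complex=2> (values illustrative; eager invocation assumed):

    import numpy as np
    import mindspore.dataset.audio as audio

    freq = 1025
    waveform = np.random.random([freq, 300, 2]).astype(np.float32)
    phase_advance = np.linspace(0, np.pi, freq).reshape(freq, 1).astype(np.float32)

    out = audio.PhaseVocoder(rate=1.3, phase_advance=phase_advance)(waveform)
    print(out.shape)   # time axis stretched by a factor of 1/1.3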
mindspore/dataset/core/config.py
CHANGED

@@ -34,6 +34,7 @@ from mindspore.dataset.core.validator_helpers import replace_none, type_check, c
 from mindspore.dataset.debug import DebugHook, PrintMetaDataHook
 from mindspore.dataset.core.validator_helpers import check_independent_mode
 
+
 __all__ = ['set_sending_batches', 'load', '_init_device_info',
            'set_seed', 'get_seed',
            'set_prefetch_size', 'get_prefetch_size',
@@ -1173,38 +1174,3 @@ def get_multiprocessing_start_method():
         >>> multiprocessing_start_method = ds.config.get_multiprocessing_start_method()
     """
     return _config.get_multiprocessing_start_method()
-
-
-def set_video_backend(backend):
-    """
-    Set the backend used to decode videos.
-
-    Args:
-        backend (str): Type of the video backend. It can be "CPU" or "Ascend".
-
-    Raises:
-        TypeError: If `backend` is not of type str.
-        ValueError: If `backend` is not "CPU" or "Ascend".
-
-    Examples:
-        >>> import mindspore.dataset as ds
-        >>> ds.config.set_video_backend("CPU")
-    """
-
-    type_check(backend, (str,), "backend")
-    check_valid_str(backend, ["CPU", "Ascend"], "backend")
-    _config.set_video_backend(backend)
-
-
-def get_video_backend():
-    """
-    Returns the currently active backend used to decode videos.
-
-    Returns:
-        str, backend used to decode videos.
-
-    Examples:
-        >>> import mindspore.dataset as ds
-        >>> backend = ds.config.get_video_backend()
-    """
-    return _config.get_video_backend()