mindspore-2.4.0-cp311-cp311-manylinux1_x86_64.whl → mindspore-2.4.1-cp311-cp311-manylinux1_x86_64.whl
This diff shows the contents of the two publicly released package versions as they appear in their public registries, and is provided for informational purposes only.
Note: this version of mindspore has been flagged as a potentially problematic release.
- mindspore/.commit_id +1 -1
- mindspore/_c_dataengine.cpython-311-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-311-x86_64-linux-gnu.so +0 -0
- mindspore/common/initializer.py +51 -15
- mindspore/common/parameter.py +18 -4
- mindspore/common/tensor.py +15 -49
- mindspore/communication/comm_func.py +7 -7
- mindspore/context.py +9 -0
- mindspore/include/mindapi/base/format.h +13 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_ops.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/all_finite.json +10 -10
- mindspore/lib/plugin/ascend/custom_ascendc_ops/op_impl/ai_core/tbe/kernel/config/ascend910b/binary_info_config.json +8 -8
- mindspore/lib/plugin/ascend/custom_compiler/setup.py +1 -1
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_internal_kernels.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/include/asdops/utils/rt/base/types.h +5 -5
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/libasdops_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/asdops/lib/liblcal_static.a +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/include/acme_op.h +1 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/paged_attention_op.h +6 -1
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/include/acme/src/ops/host_src/rms_norm_op.h +4 -3
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libapply_rotary_pos_emb_310p_impl.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/lib/libms_kernels_internal.so +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_bf16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bnsd_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_full_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/flash_attention_score/flash_attention_score_fp16_bsh_tri_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bnsd_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_bf16_bsh_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bnsd_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/internal_kernel/op_kernels/ascend910b/paged_attention/paged_attention_fp16_bsh_mix.o +0 -0
- mindspore/lib/plugin/ascend/ms_kernels_internal/lccl/lib/liblcal.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/mint/__init__.py +490 -2
- mindspore/mint/nn/__init__.py +2 -2
- mindspore/mint/optim/adamw.py +6 -14
- mindspore/nn/cell.py +1 -3
- mindspore/nn/layer/basic.py +24 -7
- mindspore/nn/layer/embedding.py +31 -14
- mindspore/nn/optim/tft_wrapper.py +12 -15
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +20 -1
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +6 -0
- mindspore/ops/auto_generate/gen_extend_func.py +33 -0
- mindspore/ops/auto_generate/gen_ops_def.py +52 -3
- mindspore/ops/auto_generate/gen_ops_prim.py +155 -6
- mindspore/ops/function/array_func.py +2 -0
- mindspore/ops/function/math_func.py +7 -1
- mindspore/ops/function/random_func.py +221 -7
- mindspore/ops/operations/__init__.py +1 -1
- mindspore/ops/operations/array_ops.py +3 -1
- mindspore/ops/operations/comm_ops.py +21 -0
- mindspore/ops/operations/manually_defined/ops_def.py +8 -10
- mindspore/parallel/_auto_parallel_context.py +3 -1
- mindspore/parallel/_cell_wrapper.py +2 -0
- mindspore/parallel/_tensor.py +46 -2
- mindspore/parallel/_utils.py +40 -21
- mindspore/parallel/transform_safetensors.py +196 -43
- mindspore/profiler/profiling.py +5 -1
- mindspore/run_check/_check_version.py +4 -2
- mindspore/train/_utils.py +92 -32
- mindspore/train/callback/_checkpoint.py +12 -9
- mindspore/train/callback/_on_request_exit.py +12 -1
- mindspore/train/callback/_tft_register.py +27 -4
- mindspore/train/dataset_helper.py +10 -2
- mindspore/train/model.py +20 -0
- mindspore/train/serialization.py +8 -18
- mindspore/version.py +1 -1
- {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +8 -6
- {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +97 -97
- {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.4.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/ops/auto_generate/gen_ops_prim.py:
```diff
@@ -81,6 +81,7 @@ from mindspore._c_expression import pyboost_copy_ext
 from mindspore._c_expression import pyboost_copy
 from mindspore._c_expression import pyboost_cos
 from mindspore._c_expression import pyboost_cosh
+from mindspore._c_expression import pyboost_count_nonzero
 from mindspore._c_expression import pyboost_cross
 from mindspore._c_expression import pyboost_cummax
 from mindspore._c_expression import pyboost_cummin_ext
```
```diff
@@ -116,7 +117,6 @@ from mindspore._c_expression import pyboost_gather_d_grad_v2
 from mindspore._c_expression import pyboost_gather_d
 from mindspore._c_expression import pyboost_gelu_grad
 from mindspore._c_expression import pyboost_gelu
-from mindspore._c_expression import pyboost_generator
 from mindspore._c_expression import pyboost_greater_equal
 from mindspore._c_expression import pyboost_greater
 from mindspore._c_expression import pyboost_grid_sampler_2d_grad
```
```diff
@@ -195,6 +195,7 @@ from mindspore._c_expression import pyboost_muls
 from mindspore._c_expression import pyboost_multinomial_ext
 from mindspore._c_expression import pyboost_mv
 from mindspore._c_expression import pyboost_nan_to_num
+from mindspore._c_expression import pyboost_ne_scalar
 from mindspore._c_expression import pyboost_neg
 from mindspore._c_expression import pyboost_non_zero_ext
 from mindspore._c_expression import pyboost_non_zero
```
```diff
@@ -214,6 +215,11 @@ from mindspore._c_expression import pyboost_prelu
 from mindspore._c_expression import pyboost_prod_ext
 from mindspore._c_expression import pyboost_rand_ext
 from mindspore._c_expression import pyboost_rand_like_ext
+from mindspore._c_expression import pyboost_randint_like
+from mindspore._c_expression import pyboost_randint
+from mindspore._c_expression import pyboost_randn_like
+from mindspore._c_expression import pyboost_randn
+from mindspore._c_expression import pyboost_randperm_ext
 from mindspore._c_expression import pyboost_reciprocal
 from mindspore._c_expression import pyboost_reduce_all
 from mindspore._c_expression import pyboost_reduce_any
```
```diff
@@ -250,6 +256,7 @@ from mindspore._c_expression import pyboost_scatter_add_ext
 from mindspore._c_expression import pyboost_scatter
 from mindspore._c_expression import pyboost_scatter_value
 from mindspore._c_expression import pyboost_searchsorted
+from mindspore._c_expression import pyboost_select_ext
 from mindspore._c_expression import pyboost_select
 from mindspore._c_expression import pyboost_select_v2
 from mindspore._c_expression import pyboost_selu_ext
```
```diff
@@ -317,6 +324,7 @@ from mindspore._c_expression import pyboost_zeros
 from mindspore._c_expression import pyboost_add_rmsnorm_quant_v2
 from mindspore._c_expression import pyboost_dynamic_quant_ext
 from mindspore._c_expression import pyboost_grouped_matmul
+from mindspore._c_expression import pyboost_kv_cache_scatter_update
 from mindspore._c_expression import pyboost_moe_finalize_routing
 from mindspore._c_expression import pyboost_quant_batch_matmul
 from mindspore._c_expression import pyboost_quant_v2
```
```diff
@@ -3770,6 +3778,36 @@ class Cosh(Primitive):
 cosh_op=Cosh()
 
 
+class CountNonZero(Primitive):
+    r"""
+    .. code-block::
+
+        prim = ops.CountNonZero()
+        out = prim(input, dim)
+
+    is equivalent to
+
+    .. code-block::
+
+        ops.count_nonzero(input, dim)
+
+    Refer to :func:`mindspore.ops.count_nonzero` for more details.
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('input'),
+        sig.make_sig('dim', default=None),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, input, dim=None):
+        return _convert_stub(pyboost_count_nonzero(self, [input, dim]))
+
+count_nonzero_op=CountNonZero()
+
+
 class Cross(Primitive):
     r"""
     Returns the cross product of vectors in dimension `dim` of input and other.
```
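The generated wrapper above is documented as equivalent to the functional form. A minimal sketch of both paths, assuming mindspore 2.4.1 in PyNative mode (the `ops.CountNonZero` name is taken from the docstring in the hunk; the pre-existing functional `ops.count_nonzero` call is shown with defaults):

```python
# Sketch only: exercises the CountNonZero primitive added above.
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor([[0, 1, 2], [0, 0, 3]], ms.int32)
prim = ops.CountNonZero()
out = prim(x, None)            # dim=None counts over all elements
out2 = ops.count_nonzero(x)    # functional form, documented as equivalent
print(out, out2)
```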
```diff
@@ -6490,7 +6528,8 @@ class Generator(Primitive):
         self.add_prim_attr("side_effect_mem", True)
 
     def __call__(self, cmd, inputs):
-        return
+        return super().__call__(cmd, inputs)
+
 
 generator_op=Generator()
 
```
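This is the substantive fix in the hunk: in 2.4.0 the generated `Generator.__call__` returned `None`, so generator commands issued through it were silently dropped; 2.4.1 routes them through `Primitive.__call__`. A hedged sketch of the user-visible effect, assuming PyNative mode and that the high-level `mindspore.Generator` drives this primitive underneath:

```python
# Sketch: user-level Generator state now round-trips (assumption: the
# mindspore.Generator wrapper dispatches through the primitive above).
import mindspore as ms

g = ms.Generator()
g.manual_seed(42)
print(g.initial_seed())   # expected: 42
```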
```diff
@@ -8678,6 +8717,9 @@ class LinSpaceExt(Primitive):
         &output = [start, start+step, start+2*step, ... , end]
         \end{aligned}
 
+    .. warning::
+        Atlas training series does not support int16 dtype currently.
+
     Inputs:
         - **start** (Union[float, int]) - Start value of interval.
           It can be a float or integer.
```
```diff
@@ -11245,6 +11287,115 @@ class RandLikeExt(Primitive):
 rand_like_ext_op=RandLikeExt()
 
 
+class RandIntLike(Primitive):
+    r"""
+    
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('input'),
+        sig.make_sig('low'),
+        sig.make_sig('high'),
+        sig.make_sig('seed'),
+        sig.make_sig('offset'),
+        sig.make_sig('dtype', default=None),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, input, low, high, seed, offset, dtype=None):
+        return _convert_stub(pyboost_randint_like(self, [input, low, high, seed, offset, dtype if dtype is None else dtype_to_type_id('RandIntLike', 'dtype', dtype)]))
+
+randint_like_op=RandIntLike()
+
+
+class RandInt(Primitive):
+    r"""
+    
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('low'),
+        sig.make_sig('high'),
+        sig.make_sig('shape'),
+        sig.make_sig('seed'),
+        sig.make_sig('offset'),
+        sig.make_sig('dtype', default=None),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, low, high, shape, seed, offset, dtype=None):
+        return _convert_stub(pyboost_randint(self, [low, high, shape, seed, offset, dtype if dtype is None else dtype_to_type_id('RandInt', 'dtype', dtype)]))
+
+randint_op=RandInt()
+
+
+class RandnLike(Primitive):
+    r"""
+    
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('input'),
+        sig.make_sig('seed'),
+        sig.make_sig('offset'),
+        sig.make_sig('dtype', default=None),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, input, seed, offset, dtype=None):
+        return _convert_stub(pyboost_randn_like(self, [input, seed, offset, dtype if dtype is None else dtype_to_type_id('RandnLike', 'dtype', dtype)]))
+
+randn_like_op=RandnLike()
+
+
+class Randn(Primitive):
+    r"""
+    
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('shape'),
+        sig.make_sig('seed'),
+        sig.make_sig('offset'),
+        sig.make_sig('dtype', default=None),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, shape, seed, offset, dtype=None):
+        return _convert_stub(pyboost_randn(self, [shape, seed, offset, dtype if dtype is None else dtype_to_type_id('Randn', 'dtype', dtype)]))
+
+randn_op=Randn()
+
+
+class RandpermExt(Primitive):
+    r"""
+    
+    """
+    __mindspore_signature__ = (
+        sig.make_sig('n'),
+        sig.make_sig('seed'),
+        sig.make_sig('offset'),
+        sig.make_sig('dtype', default=mstype.int64),
+    )
+
+    @prim_arg_register
+    def __init__(self):
+        pass
+
+    def __call__(self, n, seed, offset, dtype=mstype.int64):
+        return _convert_stub(pyboost_randperm_ext(self, [n, seed, offset, dtype_to_type_id('RandpermExt', 'dtype', dtype)]))
+
+randperm_ext_op=RandpermExt()
+
+
 class RandpermV2(Primitive):
     r"""
     .. code-block::
```
```diff
@@ -13631,8 +13782,7 @@ class SelectExt(Primitive):
         pass
 
     def __call__(self, input, dim, index):
-        return
-
+        return _convert_stub(pyboost_select_ext(self, [input, dim, index]))
 
 select_ext_op=SelectExt()
 
```
```diff
@@ -16494,8 +16644,7 @@ class KVCacheScatterUpdate(Primitive):
         self.add_prim_attr("side_effect_mem", True)
 
     def __call__(self, var, indices, updates, axis, reduce='none'):
-        return
-
+        return _convert_stub(pyboost_kv_cache_scatter_update(self, [var, indices, updates, axis, str_to_enum('KVCacheScatterUpdate', 'reduce', reduce)]))
 
 kv_cache_scatter_update_op=KVCacheScatterUpdate()
 
```
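`SelectExt` above and `KVCacheScatterUpdate` here share the same 2.4.0 bug and the same fix: `__call__` returned `None` instead of dispatching to the pyboost kernel. A sketch of the now-working direct call, assuming mindspore 2.4.1 on Ascend; the import path follows the file list above, and `SelectExt` is taken to slice `index` along `dim` (an assumption, inferred from its signature):

```python
# Sketch only: direct use of a generated op instance in PyNative mode.
import mindspore as ms
from mindspore import Tensor
from mindspore.ops.auto_generate.gen_ops_prim import select_ext_op

x = Tensor([[1.0, 2.0], [3.0, 4.0]], ms.float32)
print(select_ext_op(x, 0, 1))   # expected: the row [3. 4.]
```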
mindspore/ops/function/array_func.py:
```diff
@@ -1312,6 +1312,8 @@ def unique_with_pad(x, pad_num):
 
     .. warning::
         :func:`mindspore.ops.unique_with_pad` is deprecated from version 2.4 and will be removed in a future version.
+        Please use the :func:`mindspore.ops.unique` combined with :func:`mindspore.ops.pad` to realize
+        the same function.
 
     Args:
         x (Tensor): The tensor need to be unique. Must be 1-D vector with types: int32, int64.
```
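The deprecation text now names a concrete replacement. A sketch of that `unique` plus `pad` combination for a 1-D int32 input, reproducing what `unique_with_pad` provided:

```python
# Sketch: pad the unique values out to the input length, as the
# deprecation message recommends.
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor([1, 1, 2, 3, 3], ms.int32)
uniq, _ = ops.unique(x)                # unique values, input order preserved
pad_num = 0
padded = ops.pad(uniq, (0, x.shape[0] - uniq.shape[0]), value=pad_num)
print(padded)                          # [1 2 3 0 0]
```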
mindspore/ops/function/math_func.py:
```diff
@@ -2761,6 +2761,9 @@ def linspace_ext(start, end, steps, *, dtype=None):
         &output = [start, start+step, start+2*step, ... , end]
         \end{aligned}
 
+    .. warning::
+        Atlas training series does not support int16 dtype currently.
+
     Args:
         start (Union[float, int]): Start value of interval.
             It can be a float or integer.
```
```diff
@@ -7518,7 +7521,7 @@ def norm_ext(input, p='fro', dim=None, keepdim=False, *, dtype=None):
         This is an experimental API that is subject to change or deletion.
 
     Args:
-        input (Tensor): The input of
+        input (Tensor): The input of norm with data type of bfloat16, float16 or float32.
             The shape is :math:`(*)` where :math:`*` means, any number of additional dimensions.
         p (Union[int, float, inf, -inf, 'fro', 'nuc'], optional): norm's mode. refer to the table above for
             behavior. Default: ``fro`` .
```
```diff
@@ -7554,6 +7557,9 @@ def norm_ext(input, p='fro', dim=None, keepdim=False, *, dtype=None):
         >>> print(ops.function.math_func.norm_ext(x, 2.0))
         38.327538
     """
+    if not isinstance(input, (Tensor, Tensor_)):
+        raise TypeError(f"For `norm_ext`, the `input` must be Tensor!, but get {type(input)}.")
+
     if (dim is not None) or keepdim or (dtype is not None):
         raise ValueError(f"For `norm_ext`, the value of `dim`, `keepdim` and `dtype` must be default value currently.")
 
```
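The added guard turns a previously obscure failure into an immediate `TypeError`. A quick sketch of both sides of the check, mirroring the docstring example above:

```python
# Sketch: the new type check in norm_ext (assumes mindspore 2.4.1).
import mindspore as ms
from mindspore import Tensor, ops

x = Tensor([[1.0, 2.0], [3.0, 4.0]], ms.float32)
print(ops.function.math_func.norm_ext(x, 2.0))   # 2-norm over all elements
try:
    ops.function.math_func.norm_ext([1.0, 2.0])  # plain list is now rejected
except TypeError as e:
    print(e)
```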
mindspore/ops/function/random_func.py:
```diff
@@ -30,7 +30,8 @@ from mindspore.common.api import _function_forbid_reuse
 from mindspore.ops.auto_generate import randperm
 from mindspore.common.generator import default_generator
 from mindspore.ops.auto_generate import UniformExt, NormalTensorTensor, \
-    NormalTensorFloat, NormalFloatTensor, NormalFloatFloat, RandExt, RandLikeExt, MultinomialExt
+    NormalTensorFloat, NormalFloatTensor, NormalFloatFloat, RandExt, RandLikeExt, MultinomialExt, \
+    Randn, RandnLike, RandInt, RandIntLike, RandpermExt
 
 normal_tensor_tensor_op = NormalTensorTensor()
 normal_tensor_float_op = NormalTensorFloat()
```
```diff
@@ -42,10 +43,15 @@ real_div_ = P.RealDiv()
 reshape_ = P.Reshape()
 shape_ = P.Shape()
 top_k_ = P.TopK()
+randperm_ext_ = RandpermExt()
 uniform_ = UniformExt()
 rand_ext_ = RandExt()
 rand_like_ext_ = RandLikeExt()
 multinomial_ext_ = MultinomialExt()
+randn_ = Randn()
+randn_like_ = RandnLike()
+randint_ = RandInt()
+randint_like_ = RandIntLike()
 generator_step_ = Tensor(10, mstype.int64)
 
 
```
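All of the wrappers below share one plumbing pattern: reserve a block of `generator_step_` (10) states from a `Generator`, then pass the resulting `seed`/`offset` pair into the stateless primitive. A sketch of that contract, assuming the internal `_step` API shown in these hunks:

```python
# Sketch of the seed/offset handshake used by every sampler below
# (_step is internal; shown only to illustrate the plumbing).
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.generator import default_generator

generator_step_ = Tensor(10, mstype.int64)
seed, offset = default_generator._step(generator_step_)  # pylint: disable=protected-access
# seed/offset are then forwarded, e.g. rand_ext_(size, seed, offset, dtype)
```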
```diff
@@ -287,7 +293,8 @@ def uniform_ext(tensor, a, b, generator=None):
     """
     if generator is None:
         generator = default_generator
-    seed, offset = generator._step(generator_step_)
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
     return uniform_(tensor, a, b, seed, offset)
 
 
```
```diff
@@ -755,7 +762,8 @@ def normal_ext(mean=0.0, std=1.0, size=None, generator=None):
     """
     if generator is None:
         generator = default_generator
-    seed, offset = generator._step(generator_step_)
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
 
     is_mean_tensor = isinstance(mean, Tensor)
     is_std_tensor = isinstance(std, Tensor)
```
```diff
@@ -1129,7 +1137,8 @@ def rand_ext(*size, generator=None, dtype=None):
     """
     if not generator:
         generator = default_generator
-    seed, offset = generator._step(generator_step_)
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
     return rand_ext_(size, seed, offset, dtype)
 
 
```
```diff
@@ -1163,10 +1172,174 @@ def rand_like_ext(input, *, dtype=None):
     >>> print(ops.function.random_func.rand_like_ext(a, dtype=ms.float32).shape)
     (2, 3)
     """
-    seed, offset = default_generator._step(generator_step_)
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_)
     return rand_like_ext_(input, seed, offset, dtype)
 
 
+@_function_forbid_reuse
+def randn_ext(*size, generator=None, dtype=None):
+    r"""
+    Returns a new tensor filled with numbers from the normal distribution over an interval :math:`[0, 1)`
+    based on the given shape and dtype.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g. :math:`(2, 3)` or :math:`2`.
+
+    Keyword Args:
+        generator (:class:`mindspore.Generator`, optional): a pseudorandom number generator.
+            Default: ``None``, uses the default pseudorandom number generator.
+        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
+            `mindspore.float32` will be applied. Default: ``None`` .
+
+    Returns:
+        Tensor, with the designated shape and dtype, filled with random numbers from the normal distribution on
+        the interval :math:`[0, 1)`.
+
+    Raises:
+        ValueError: If `dtype` is not a `mstype.float_type` type.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> print(ops.function.random_func.randn_ext(2, 3).shape)
+        (2, 3)
+    """
+    if not generator:
+        generator = default_generator
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
+    return randn_(size, seed, offset, dtype)
+
+
+@_function_forbid_reuse
+def randn_like_ext(input, *, dtype=None):
+    r"""
+    Returns a new tensor filled with numbers from the normal distribution over an interval :math:`[0, 1)`
+    based on the given dtype and shape of the input tensor.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        input (Tensor): Input Tensor to specify the output shape and its default dtype.
+
+    Keyword Args:
+        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
+            the same dtype of `input` will be applied. Default: ``None`` .
+
+    Returns:
+        Tensor, with the designated shape and dtype, filled with random numbers from the normal distribution on
+        the interval :math:`[0, 1)`.
+
+    Raises:
+        ValueError: If `dtype` is not a `mstype.float_type` type.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor, ops
+        >>> a = Tensor([[2, 3, 4], [1, 2, 3]])
+        >>> print(ops.function.random_func.randn_like_ext(a, dtype=ms.float32).shape)
+        (2, 3)
+    """
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_)
+    return randn_like_(input, seed, offset, dtype)
+
+
+@_function_forbid_reuse
+def randint_ext(low, high, size, *, generator=None, dtype=None):
+    r"""
+    Returns a new tensor filled with integer numbers from the uniform distribution over an interval :math:`[low, high)`
+    based on the given shape and dtype.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        low (int): the lower bound of the generated random number
+        high (int): the upper bound of the generated random number
+        size (Union[tuple(int), list(int)]): Shape of the new tensor, e.g. :math:`(2, 3)`.
+
+    Keyword Args:
+        generator (:class:`mindspore.Generator`, optional): a pseudorandom number generator.
+            Default: ``None``, uses the default pseudorandom number generator.
+        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype. If None,
+            `mindspore.int64` will be applied. Default: ``None`` .
+
+    Returns:
+        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
+        the interval :math:`[low, high)`.
+
+    Raises:
+        TypeError: If `size` is not a tuple.
+        TypeError: If `low` or `high` is not integer.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> print(ops.function.random_func.randint_ext(0, 5, (2, 3)).shape)
+        (2, 3)
+    """
+    if not generator:
+        generator = default_generator
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
+    return randint_(low, high, size, seed, offset, dtype)
+
+
+@_function_forbid_reuse
+def randint_like_ext(input, low, high, *, dtype=None):
+    r"""
+    Returns a new tensor filled with integer numbers from the uniform distribution over an interval :math:`[low, high)`
+    based on the given dtype and shape of the input tensor.
+
+    .. warning::
+        This is an experimental API that is subject to change or deletion.
+
+    Args:
+        input (Tensor): Input Tensor to specify the output shape and its default dtype.
+        low (int): the lower bound of the generated random number
+        high (int): the upper bound of the generated random number
+
+    Keyword Args:
+        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype. If None,
+            the same dtype of `input` will be applied. Default: ``None`` .
+
+    Returns:
+        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
+        the interval :math:`[low, high)`.
+
+    Raises:
+        TypeError: If `low` or `high` is not integer.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore as ms
+        >>> from mindspore import Tensor, ops
+        >>> a = Tensor([[2, 3, 4], [1, 2, 3]])
+        >>> low = 0
+        >>> high = 5
+        >>> print(ops.function.random_func.randint_like_ext(a, low, high, dtype=ms.int32).shape)
+        (2, 3)
+    """
+    seed, offset = default_generator._step(  # pylint: disable=protected-access
+        generator_step_)
+    return randint_like_(input, low, high, seed, offset, dtype)
+
+
 @_function_forbid_reuse
 def randn(*size, dtype=None, seed=None):
     r"""
```
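A reproducibility sketch for the new samplers, assuming mindspore 2.4.1 on Ascend per the Supported Platforms sections above: re-seeding a `Generator` replays the same draws.

```python
# Sketch: new samplers are deterministic under an explicit Generator.
import mindspore as ms
from mindspore.ops.function.random_func import randn_ext, randint_ext

g = ms.Generator()
g.manual_seed(7)
a = randn_ext(2, 3, generator=g)
g.manual_seed(7)
b = randn_ext(2, 3, generator=g)
print((a == b).all())                    # expected: True, identical draws
print(randint_ext(0, 5, (2, 3)).shape)   # (2, 3)
```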
```diff
@@ -1395,6 +1568,47 @@ def randint_like(input, low, high, seed=None, *, dtype=None):
     return cast_(output, dtype)
 
 
+def randperm_ext(n, *, generator=None, dtype=mstype.int64):
+    r"""
+    Generates random permutation of integers from 0 to n-1.
+
+    .. warning::
+        - This is an experimental API that is subject to change or deletion.
+
+
+    Args:
+        n (Union[Tensor, int]): size of the permutation. int or Tensor with shape: () or (1,) and
+            data type int64. The value of `n` must be greater than zero.
+        generator (:class:`mindspore.Generator`, optional): a pseudorandom number generator.
+            Default: ``None``, uses the default pseudorandom number generator.
+        dtype (mindspore.dtype, optional): The type of output. Default: mstype.int64.
+
+    Returns:
+        Tensor with shape (n,) and type `dtype`.
+
+    Raises:
+        TypeError: If `dtype` is not supported.
+        ValueError: If `n` is a negative or 0 element.
+        ValueError: If `n` is larger than the maximal data of the set dtype.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> from mindspore import ops
+        >>> from mindspore import dtype as mstype
+        >>> n = 4
+        >>> output = ops.randperm_ext(n, dtype=mstype.int64)
+        >>> print(output.shape)
+        (4,)
+    """
+    if not generator:
+        generator = default_generator
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
+    return randperm_ext_(n, seed, offset, dtype)
+
+
 @_function_forbid_reuse
 def poisson(shape, mean, seed=None):
     r"""
```
```diff
@@ -1675,10 +1889,10 @@ def multinomial_ext(input, num_samples, replacement=False, *, generator=None):
     >>> # [[0 0 0 0 0 0 0 0 1 0]
     >>> # [1 1 1 1 1 0 1 1 1 1]]
     """
-
     if generator is None:
         generator = default_generator
-    seed, offset = generator._step(generator_step_)
+    seed, offset = generator._step(  # pylint: disable=protected-access
+        generator_step_)
     return multinomial_ext_(input, num_samples, replacement, seed, offset)
 
 
```
mindspore/ops/operations/__init__.py:
```diff
@@ -55,7 +55,7 @@ from .comm_ops import (AllGather, AllReduce, Reduce, NeighborExchange, NeighborE
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
                        _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad,
                        _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather,
-                       _VirtualPipelineEnd, AlltoAllV, ReduceScatter)
+                       _VirtualPipelineEnd, AlltoAllV, ReduceScatter, _VirtualAssignKvCache)
 from .control_ops import GeSwitch, Merge
 from .custom_ops import (Custom)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
```
mindspore/ops/operations/array_ops.py:
```diff
@@ -771,12 +771,14 @@ class Padding(Primitive):
 class UniqueWithPad(Primitive):
     """
     'ops.UniqueWithPad' is deprecated from version 2.4 and will be removed in a future version.
+    Please use the :func:`mindspore.ops.unique` combined with :func:`mindspore.ops.pad` to realize
+    the same function.
 
     Supported Platforms:
         Deprecated
     """
 
-    @deprecated("2.4", "ops.
+    @deprecated("2.4", "ops.unique and ops.pad", False)
     @prim_attr_register
     def __init__(self):
         """init UniqueWithPad"""
```
mindspore/ops/operations/comm_ops.py:
```diff
@@ -1682,6 +1682,27 @@ class _VirtualAssignAdd(PrimitiveWithInfer):
 virtual_assign_add = _VirtualAssignAdd()
 
 
+class _VirtualAssignKvCache(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do Assign kv cache in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize _VirtualAssignAdd."""
+        self.add_prim_attr('order_enforce_skip', True)
+        self.add_prim_attr('side_effect_backprop_mem', True)
+
+    def infer_shape(self, x_shape, y_shape, kv_equal_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype, kv_equal_dtype):
+        return x_dtype
+virtual_assign_kv_cache = _VirtualAssignKvCache()
+
+
 class _VirtualAccuGrad(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, return y in backward. It is only for
```
mindspore/ops/operations/manually_defined/ops_def.py:
```diff
@@ -1171,17 +1171,15 @@ class Cast(Primitive):
         self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
 
     def check_elim(self, x, dtype):
-        if isinstance(x, (Tensor, numbers.Number, Parameter)):
-            if isinstance(x, Parameter):
-                data = x.data
-                if data.dtype == dtype:
-                    return (True, x)
-            if isinstance(x, Tensor) and x.dtype == dtype:
-                x = Tensor(x)
-                x.set_cast_dtype()
+        if isinstance(x, Parameter):
+            data = x.data
+            if data.dtype == dtype:
                 return (True, x)
-            if isinstance(x, numbers.Number):
-                return (True, Tensor(x, dtype=dtype))
+        if isinstance(x, Tensor) and x.dtype == dtype:
+            x.set_cast_dtype()
+            return (True, x)
+        if isinstance(x, numbers.Number):
+            return (True, Tensor(x, dtype=dtype))
         return (False, None)
 
     def __call__(self, input_x, dtype):
```
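The rewrite flattens the nested `isinstance` chain and drops the redundant `x = Tensor(x)` copy. A behavioral sketch of `check_elim` after the change (it decides whether a cast can be elided before any kernel dispatch; assumes mindspore 2.4.1):

```python
# Sketch: constant-elimination decisions made by Cast.check_elim.
import mindspore as ms
from mindspore import Tensor, ops

cast = ops.Cast()
x = Tensor([1.0, 2.0], ms.float32)
print(cast.check_elim(x, ms.float32))   # (True, x): dtype already matches
print(cast.check_elim(3, ms.float32))   # (True, Tensor(3, float32)): folded
print(cast.check_elim(x, ms.int32))     # (False, None): real cast required
```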
mindspore/parallel/_auto_parallel_context.py:
```diff
@@ -76,6 +76,7 @@ class _PipelineConfig:
 class _PipelineScheduler:
     PIPELINE_1F1B = "1f1b"
     PIPELINE_GPIPE = "gpipe"
+    PIPELINE_SEQPIPE = "seqpipe"
 
 
 class _AutoParallelContext:
```
```diff
@@ -914,7 +915,8 @@ class _AutoParallelContext:
                                    pipeline_config[pp_interleave])
 
         Validator.check_string(pipeline_config[pp_scheduler], [_PipelineScheduler.PIPELINE_1F1B,
-                                                               _PipelineScheduler.PIPELINE_GPIPE])
+                                                               _PipelineScheduler.PIPELINE_GPIPE,
+                                                               _PipelineScheduler.PIPELINE_SEQPIPE])
         if not pipeline_config[pp_interleave] and pipeline_config[pp_scheduler] != _PipelineScheduler.PIPELINE_1F1B:
             raise ValueError(f"When pipeline_interleave is False, {pp_scheduler} is not supported")
 
```
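With the validator extended, `"seqpipe"` becomes a legal `pipeline_scheduler` value. A config sketch, assuming a semi-auto-parallel multi-device setup and respecting the constraint enforced just below the added line (`pipeline_interleave` must be true for any scheduler other than 1f1b):

```python
# Sketch: opting into the new scheduler (assumes mindspore 2.4.1,
# multi-device pipeline-parallel training).
import mindspore as ms

ms.set_auto_parallel_context(
    parallel_mode="semi_auto_parallel",
    pipeline_stages=4,
    pipeline_config={"pipeline_interleave": True, "pipeline_scheduler": "seqpipe"},
)
```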
mindspore/parallel/_utils.py:
```diff
@@ -126,6 +126,8 @@ def _restore_parallel_context(origin_parallel_mode, origin_dataset_strategy):
     if context.get_context("mode") == context.GRAPH_MODE:
         context.set_auto_parallel_context(parallel_mode=origin_parallel_mode)
     if origin_dataset_strategy != "data_parallel":
+        if origin_dataset_strategy is not None and isinstance(origin_dataset_strategy, list):
+            origin_dataset_strategy = tuple(tuple(ds_item) for ds_item in origin_dataset_strategy)
         context.set_auto_parallel_context(dataset_strategy=origin_dataset_strategy)
 
 
```
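The added normalization handles a round-trip quirk: a `dataset_strategy` tuple can come back as nested lists (for example after serialization), while `set_auto_parallel_context` expects tuples. The coercion is exactly the two added lines, shown standalone:

```python
# Sketch: the nested list -> nested tuple coercion added above.
origin_dataset_strategy = [[8, 1], [8, 1]]   # hypothetical strategy for 8 devices
if origin_dataset_strategy is not None and isinstance(origin_dataset_strategy, list):
    origin_dataset_strategy = tuple(tuple(ds_item) for ds_item in origin_dataset_strategy)
print(origin_dataset_strategy)               # ((8, 1), (8, 1))
```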