mindspore 2.7.0-cp310-cp310-win_amd64.whl → 2.7.1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +4 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_extends/parse/compile_config.py +24 -1
- mindspore/_extends/parse/deprecated/deprecated_tensor_method.py +6 -2
- mindspore/_extends/parse/resources.py +1 -1
- mindspore/_extends/parse/standard_method.py +8 -1
- mindspore/_extends/parse/trope.py +2 -1
- mindspore/_extends/pijit/pijit_func_white_list.py +7 -22
- mindspore/avcodec-59.dll +0 -0
- mindspore/avdevice-59.dll +0 -0
- mindspore/avfilter-8.dll +0 -0
- mindspore/avformat-59.dll +0 -0
- mindspore/avutil-57.dll +0 -0
- mindspore/boost/base.py +29 -2
- mindspore/common/_decorator.py +3 -2
- mindspore/common/_grad_function.py +3 -1
- mindspore/common/_tensor_cpp_method.py +1 -1
- mindspore/common/_tensor_docs.py +275 -64
- mindspore/common/_utils.py +0 -44
- mindspore/common/api.py +285 -35
- mindspore/common/dump.py +7 -108
- mindspore/common/dynamic_shape/auto_dynamic_shape.py +1 -3
- mindspore/common/hook_handle.py +60 -0
- mindspore/common/jit_config.py +5 -1
- mindspore/common/jit_trace.py +27 -12
- mindspore/common/lazy_inline.py +5 -3
- mindspore/common/parameter.py +13 -107
- mindspore/common/recompute.py +4 -11
- mindspore/common/tensor.py +16 -169
- mindspore/communication/_comm_helper.py +11 -1
- mindspore/communication/comm_func.py +138 -4
- mindspore/communication/management.py +85 -1
- mindspore/config/op_info.config +0 -15
- mindspore/context.py +5 -85
- mindspore/dataset/engine/datasets.py +8 -4
- mindspore/dataset/engine/datasets_vision.py +1 -1
- mindspore/dataset/engine/validators.py +1 -15
- mindspore/dnnl.dll +0 -0
- mindspore/{experimental/llm_boost/ascend_native → graph}/__init__.py +7 -7
- mindspore/graph/custom_pass.py +55 -0
- mindspore/include/dataset/execute.h +2 -2
- mindspore/jpeg62.dll +0 -0
- mindspore/mindrecord/__init__.py +3 -3
- mindspore/mindrecord/common/exceptions.py +1 -0
- mindspore/mindrecord/config.py +1 -1
- mindspore/{parallel/mpi → mindrecord/core}/__init__.py +4 -1
- mindspore/mindrecord/{shardheader.py → core/shardheader.py} +2 -1
- mindspore/mindrecord/{shardindexgenerator.py → core/shardindexgenerator.py} +1 -1
- mindspore/mindrecord/{shardreader.py → core/shardreader.py} +2 -1
- mindspore/mindrecord/{shardsegment.py → core/shardsegment.py} +2 -2
- mindspore/mindrecord/{shardutils.py → core/shardutils.py} +1 -1
- mindspore/mindrecord/{shardwriter.py → core/shardwriter.py} +1 -1
- mindspore/mindrecord/filereader.py +4 -4
- mindspore/mindrecord/filewriter.py +5 -5
- mindspore/mindrecord/mindpage.py +2 -2
- mindspore/mindrecord/tools/cifar10.py +1 -1
- mindspore/mindrecord/tools/cifar100.py +1 -1
- mindspore/mindrecord/tools/cifar100_to_mr.py +1 -1
- mindspore/mindrecord/tools/cifar10_to_mr.py +1 -1
- mindspore/mindrecord/tools/csv_to_mr.py +1 -1
- mindspore/mindrecord/tools/imagenet_to_mr.py +1 -1
- mindspore/mindrecord/tools/mnist_to_mr.py +1 -1
- mindspore/mindrecord/tools/tfrecord_to_mr.py +1 -1
- mindspore/mindspore_backend_common.dll +0 -0
- mindspore/mindspore_backend_manager.dll +0 -0
- mindspore/mindspore_cluster.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_cpu.dll +0 -0
- mindspore/mindspore_dump.dll +0 -0
- mindspore/mindspore_frontend.dll +0 -0
- mindspore/mindspore_glog.dll +0 -0
- mindspore/mindspore_hardware_abstract.dll +0 -0
- mindspore/mindspore_memory_pool.dll +0 -0
- mindspore/mindspore_ms_backend.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/{mindspore_ops_host.dll → mindspore_ops_cpu.dll} +0 -0
- mindspore/mindspore_profiler.dll +0 -0
- mindspore/mindspore_pyboost.dll +0 -0
- mindspore/mindspore_pynative.dll +0 -0
- mindspore/mindspore_runtime_pipeline.dll +0 -0
- mindspore/mindspore_runtime_utils.dll +0 -0
- mindspore/mindspore_tools.dll +0 -0
- mindspore/mint/__init__.py +15 -10
- mindspore/mint/distributed/distributed.py +182 -62
- mindspore/mint/nn/__init__.py +2 -16
- mindspore/mint/nn/functional.py +4 -110
- mindspore/mint/nn/layer/__init__.py +0 -2
- mindspore/mint/nn/layer/activation.py +0 -6
- mindspore/mint/nn/layer/basic.py +0 -47
- mindspore/mint/nn/layer/conv.py +4 -4
- mindspore/mint/nn/layer/normalization.py +8 -13
- mindspore/mint/nn/layer/pooling.py +0 -4
- mindspore/nn/__init__.py +1 -3
- mindspore/nn/cell.py +16 -66
- mindspore/nn/layer/basic.py +49 -1
- mindspore/nn/layer/container.py +16 -0
- mindspore/nn/layer/embedding.py +4 -169
- mindspore/nn/layer/normalization.py +2 -1
- mindspore/nn/layer/thor_layer.py +4 -85
- mindspore/nn/optim/ada_grad.py +0 -1
- mindspore/nn/optim/adafactor.py +0 -1
- mindspore/nn/optim/adam.py +31 -124
- mindspore/nn/optim/adamax.py +0 -1
- mindspore/nn/optim/asgd.py +0 -1
- mindspore/nn/optim/ftrl.py +8 -102
- mindspore/nn/optim/lamb.py +0 -1
- mindspore/nn/optim/lars.py +0 -3
- mindspore/nn/optim/lazyadam.py +25 -218
- mindspore/nn/optim/momentum.py +5 -43
- mindspore/nn/optim/optimizer.py +6 -55
- mindspore/nn/optim/proximal_ada_grad.py +0 -1
- mindspore/nn/optim/rmsprop.py +0 -1
- mindspore/nn/optim/rprop.py +0 -1
- mindspore/nn/optim/sgd.py +0 -1
- mindspore/nn/optim/tft_wrapper.py +0 -1
- mindspore/nn/optim/thor.py +0 -2
- mindspore/nn/probability/bijector/bijector.py +7 -8
- mindspore/nn/probability/bijector/gumbel_cdf.py +2 -2
- mindspore/nn/probability/bijector/power_transform.py +20 -21
- mindspore/nn/probability/bijector/scalar_affine.py +5 -5
- mindspore/nn/probability/bijector/softplus.py +13 -14
- mindspore/nn/wrap/grad_reducer.py +4 -74
- mindspore/numpy/array_creations.py +2 -2
- mindspore/numpy/fft.py +9 -9
- mindspore/{nn/reinforcement → onnx}/__init__.py +5 -8
- mindspore/onnx/onnx_export.py +137 -0
- mindspore/opencv_core4110.dll +0 -0
- mindspore/opencv_imgcodecs4110.dll +0 -0
- mindspore/{opencv_imgproc452.dll → opencv_imgproc4110.dll} +0 -0
- mindspore/ops/__init__.py +2 -0
- mindspore/ops/_grad_experimental/grad_comm_ops.py +38 -2
- mindspore/ops/_op_impl/aicpu/__init__.py +0 -10
- mindspore/ops/_op_impl/cpu/__init__.py +0 -5
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +16 -22
- mindspore/ops/auto_generate/gen_extend_func.py +2 -7
- mindspore/ops/auto_generate/gen_ops_def.py +98 -141
- mindspore/ops/auto_generate/gen_ops_prim.py +12708 -12686
- mindspore/ops/communication.py +97 -0
- mindspore/ops/composite/__init__.py +5 -2
- mindspore/ops/composite/base.py +15 -1
- mindspore/ops/composite/multitype_ops/__init__.py +3 -1
- mindspore/ops/composite/multitype_ops/_compile_utils.py +150 -8
- mindspore/ops/composite/multitype_ops/add_impl.py +7 -0
- mindspore/ops/composite/multitype_ops/mod_impl.py +27 -0
- mindspore/ops/function/__init__.py +1 -0
- mindspore/ops/function/array_func.py +14 -12
- mindspore/ops/function/comm_func.py +3883 -0
- mindspore/ops/function/debug_func.py +3 -4
- mindspore/ops/function/math_func.py +45 -54
- mindspore/ops/function/nn_func.py +75 -294
- mindspore/ops/function/random_func.py +9 -18
- mindspore/ops/functional.py +2 -0
- mindspore/ops/functional_overload.py +354 -18
- mindspore/ops/operations/__init__.py +2 -5
- mindspore/ops/operations/_custom_ops_utils.py +7 -9
- mindspore/ops/operations/_inner_ops.py +1 -38
- mindspore/ops/operations/_rl_inner_ops.py +0 -933
- mindspore/ops/operations/array_ops.py +1 -0
- mindspore/ops/operations/comm_ops.py +94 -2
- mindspore/ops/operations/custom_ops.py +228 -19
- mindspore/ops/operations/debug_ops.py +27 -29
- mindspore/ops/operations/manually_defined/ops_def.py +27 -306
- mindspore/ops/operations/nn_ops.py +2 -2
- mindspore/ops/operations/sparse_ops.py +0 -83
- mindspore/ops/primitive.py +1 -17
- mindspore/ops/tensor_method.py +72 -3
- mindspore/ops_generate/aclnn/aclnn_kernel_register_auto_cc_generator.py +5 -5
- mindspore/ops_generate/aclnn/gen_aclnn_implement.py +8 -8
- mindspore/ops_generate/api/functions_cc_generator.py +53 -4
- mindspore/ops_generate/api/tensor_func_reg_cpp_generator.py +25 -11
- mindspore/ops_generate/common/gen_constants.py +11 -10
- mindspore/ops_generate/common/op_proto.py +18 -1
- mindspore/ops_generate/common/template.py +102 -245
- mindspore/ops_generate/common/template_utils.py +212 -0
- mindspore/ops_generate/gen_custom_ops.py +69 -0
- mindspore/ops_generate/op_def/ops_def_cc_generator.py +78 -7
- mindspore/ops_generate/op_def_py/base_op_prim_py_generator.py +360 -0
- mindspore/ops_generate/op_def_py/custom_op_prim_py_generator.py +140 -0
- mindspore/ops_generate/op_def_py/op_def_py_generator.py +54 -7
- mindspore/ops_generate/op_def_py/op_prim_py_generator.py +5 -312
- mindspore/ops_generate/pyboost/auto_grad_impl_cc_generator.py +74 -17
- mindspore/ops_generate/pyboost/auto_grad_reg_cc_generator.py +22 -5
- mindspore/ops_generate/pyboost/op_template_parser.py +3 -2
- mindspore/ops_generate/pyboost/pyboost_functions_cpp_generator.py +21 -5
- mindspore/ops_generate/pyboost/pyboost_functions_h_generator.py +2 -2
- mindspore/ops_generate/pyboost/pyboost_functions_impl_cpp_generator.py +30 -10
- mindspore/ops_generate/pyboost/pyboost_grad_function_cpp_generator.py +10 -3
- mindspore/ops_generate/pyboost/pyboost_internal_kernel_info_adapter_generator.py +1 -1
- mindspore/ops_generate/pyboost/pyboost_native_grad_functions_generator.py +19 -9
- mindspore/ops_generate/pyboost/pyboost_op_cpp_code_generator.py +71 -28
- mindspore/ops_generate/pyboost/pyboost_overload_functions_cpp_generator.py +10 -9
- mindspore/ops_generate/pyboost/pyboost_utils.py +27 -16
- mindspore/ops_generate/resources/yaml_loader.py +13 -0
- mindspore/ops_generate/tensor_py_cc_generator.py +2 -2
- mindspore/parallel/_cell_wrapper.py +1 -1
- mindspore/parallel/_parallel_serialization.py +1 -4
- mindspore/parallel/_utils.py +29 -6
- mindspore/parallel/checkpoint_transform.py +18 -2
- mindspore/parallel/cluster/process_entity/_api.py +24 -32
- mindspore/parallel/cluster/process_entity/_utils.py +9 -5
- mindspore/{experimental/llm_boost/atb → parallel/distributed}/__init__.py +21 -23
- mindspore/parallel/distributed/distributed_data_parallel.py +393 -0
- mindspore/parallel/distributed/flatten_grad_buffer.py +295 -0
- mindspore/parallel/strategy.py +336 -0
- mindspore/parallel/transform_safetensors.py +117 -16
- mindspore/profiler/analysis/viewer/ascend_kernel_details_viewer.py +3 -0
- mindspore/profiler/analysis/viewer/ms_minddata_viewer.py +1 -1
- mindspore/profiler/common/constant.py +5 -0
- mindspore/profiler/common/file_manager.py +9 -0
- mindspore/profiler/common/msprof_cmd_tool.py +38 -2
- mindspore/profiler/common/path_manager.py +56 -24
- mindspore/profiler/common/profiler_context.py +2 -12
- mindspore/profiler/common/profiler_info.py +3 -3
- mindspore/profiler/common/profiler_path_manager.py +13 -0
- mindspore/profiler/common/util.py +30 -3
- mindspore/profiler/experimental_config.py +2 -1
- mindspore/profiler/platform/npu_profiler.py +33 -6
- mindspore/run_check/_check_version.py +108 -24
- mindspore/runtime/__init__.py +3 -2
- mindspore/runtime/executor.py +11 -3
- mindspore/runtime/memory.py +112 -0
- mindspore/swresample-4.dll +0 -0
- mindspore/swscale-6.dll +0 -0
- mindspore/tinyxml2.dll +0 -0
- mindspore/{experimental/llm_boost → tools}/__init__.py +5 -5
- mindspore/tools/data_dump.py +130 -0
- mindspore/tools/sdc_detect.py +91 -0
- mindspore/tools/stress_detect.py +63 -0
- mindspore/train/__init__.py +6 -6
- mindspore/train/_utils.py +5 -18
- mindspore/train/amp.py +6 -4
- mindspore/train/callback/_checkpoint.py +0 -9
- mindspore/train/callback/_train_fault_tolerance.py +69 -18
- mindspore/train/data_sink.py +1 -5
- mindspore/train/model.py +38 -211
- mindspore/train/serialization.py +126 -387
- mindspore/turbojpeg.dll +0 -0
- mindspore/utils/__init__.py +6 -3
- mindspore/utils/dlpack.py +92 -0
- mindspore/utils/dryrun.py +1 -1
- mindspore/utils/runtime_execution_order_check.py +10 -0
- mindspore/utils/sdc_detect.py +14 -12
- mindspore/utils/stress_detect.py +43 -0
- mindspore/utils/utils.py +144 -8
- mindspore/version.py +1 -1
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/METADATA +3 -2
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/RECORD +254 -267
- mindspore/experimental/llm_boost/ascend_native/llama_boost_ascend_native.py +0 -210
- mindspore/experimental/llm_boost/ascend_native/llm_boost.py +0 -52
- mindspore/experimental/llm_boost/atb/boost_base.py +0 -385
- mindspore/experimental/llm_boost/atb/llama_boost.py +0 -137
- mindspore/experimental/llm_boost/atb/qwen_boost.py +0 -124
- mindspore/experimental/llm_boost/register.py +0 -130
- mindspore/experimental/llm_boost/utils.py +0 -31
- mindspore/include/OWNERS +0 -7
- mindspore/mindspore_cpu_res_manager.dll +0 -0
- mindspore/mindspore_ops_kernel_common.dll +0 -0
- mindspore/mindspore_res_manager.dll +0 -0
- mindspore/nn/optim/_dist_optimizer_registry.py +0 -111
- mindspore/nn/reinforcement/_batch_read_write.py +0 -142
- mindspore/nn/reinforcement/_tensors_queue.py +0 -152
- mindspore/nn/reinforcement/tensor_array.py +0 -145
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +0 -113
- mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +0 -96
- mindspore/ops/_op_impl/aicpu/sparse_cross.py +0 -42
- mindspore/ops/_op_impl/cpu/buffer_append.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_get.py +0 -28
- mindspore/ops/_op_impl/cpu/buffer_sample.py +0 -28
- mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +0 -42
- mindspore/ops/operations/_tensor_array.py +0 -359
- mindspore/ops/operations/rl_ops.py +0 -288
- mindspore/parallel/_offload_context.py +0 -275
- mindspore/parallel/_recovery_context.py +0 -115
- mindspore/parallel/_transformer/__init__.py +0 -35
- mindspore/parallel/_transformer/layers.py +0 -765
- mindspore/parallel/_transformer/loss.py +0 -251
- mindspore/parallel/_transformer/moe.py +0 -693
- mindspore/parallel/_transformer/op_parallel_config.py +0 -222
- mindspore/parallel/_transformer/transformer.py +0 -3124
- mindspore/parallel/mpi/_mpi_config.py +0 -116
- mindspore/train/memory_profiling_pb2.py +0 -298
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/WHEEL +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.7.0.dist-info → mindspore-2.7.1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/embedding.py
CHANGED
@@ -17,17 +17,12 @@ from __future__ import absolute_import
 
 import mindspore.common.dtype as mstype
 import mindspore.ops as ops
-from mindspore import log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
-from mindspore.common.parameter import _get_unique_parameter_key
 from mindspore.common.initializer import initializer, Normal
-from mindspore.communication.management import get_group_size
+from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
-from mindspore.parallel._ps_context import _get_ps_context, _enable_distributed_mindrt
-from mindspore.parallel._ps_context import _is_role_worker, _is_role_pserver
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _set_rank_id
 from mindspore import _checkparam as Validator
 from mindspore.ops.primitive import constexpr, _primexpr
 from mindspore.nn.layer.basic import ClipByNorm
@@ -341,10 +336,6 @@ class EmbeddingLookup(Cell):
 max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
 or None. Default: ``None`` .
 sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: ``True`` .
-vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in
-parameter server trainning mode and 'DEVICE' target. And the moment parameter of corresponding
-optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
-memory, so suggests setting a reasonable value to avoid insufficient memory.
 dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
 
 Inputs:
@@ -358,10 +349,9 @@ class EmbeddingLookup(Cell):
 Tensor, the shape of tensor is :math:`(z_1, z_2, ..., z_N)`.
 
 Raises:
-TypeError: If `vocab_size` or `embedding_size`
+TypeError: If `vocab_size` or `embedding_size` is not an int.
 TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
 ValueError: If `vocab_size` or `embedding_size` is less than 1.
-ValueError: If `vocab_cache_size` is less than 0.
 ValueError: If `target` is neither 'CPU' nor 'DEVICE'.
 ValueError: If `slice_mode` is not one of 'batch_slice' or 'field_slice' or
 'table_row_slice' or 'table_column_slice'.
@@ -387,17 +377,14 @@ class EmbeddingLookup(Cell):
 
 def __init__(self, vocab_size, embedding_size, param_init='normal',
 target='CPU', slice_mode='batch_slice', manual_shapes=None,
-max_norm=None, sparse=True,
+max_norm=None, sparse=True, dtype=mstype.float32):
 """Initialize EmbeddingLookup."""
 super(EmbeddingLookup, self).__init__()
 Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
 self.vocab_size = Validator.check_positive_int(
 vocab_size, 'vocab_size')
-self.vocab_cache_size = Validator.check_non_negative_int(
-vocab_cache_size, 'vocab_cache_size')
 self.target = target
 self.sparse = sparse
-self.cache_enable = self.vocab_cache_size > 0
 self.forward_unique = False
 Validator.check_string(
 target, ['CPU', 'DEVICE'], 'target', self.cls_name)
@@ -409,10 +396,6 @@ class EmbeddingLookup(Cell):
 else:
 self.gatherv2 = ops.Gather()
 self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')
-self.is_ps_server = False
-enable_ps = _get_ps_context("enable_ps")
-if enable_ps:
-self._process_vocab_cache(slice_mode)
 self.embedding_size = Validator.check_positive_int(
 embedding_size, 'embedding_size', self.cls_name)
 self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
@@ -427,11 +410,6 @@ class EmbeddingLookup(Cell):
 self.shape = ops.Shape()
 if is_auto_parallel:
 self.unique = ops.Unique().shard(((1,),))
-if self.cache_enable and enable_ps:
-self._set_voacb_cache_enable_for_ps(
-vocab_cache_size, embedding_size, vocab_size, param_init, dtype=dtype)
-if is_auto_parallel:
-self.unique.add_prim_attr('cache_enable', True)
 indices_shape_size = 2
 if slice_mode == "field_slice" and is_auto_parallel:
 if not manual_shapes:
@@ -450,7 +428,7 @@ class EmbeddingLookup(Cell):
 ((get_group_size(), 1), (1, get_group_size())))
 elif slice_mode == "table_row_slice" and is_auto_parallel:
 full_batch = _get_full_batch()
-if (target == 'DEVICE' and not full_batch)
+if (target == 'DEVICE' and not full_batch):
 indices_shape_size = 1
 self.gather_revert.shard(((1, 1), (get_group_size(),)))
 self.forward_unique = True
@@ -479,9 +457,6 @@ class EmbeddingLookup(Cell):
 "table_column_slice", "batch_slice"]
 raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be in {support_mode}, "
 f"but got \"{slice_mode}\".")
-if self.cache_enable and not enable_ps:
-raise ValueError(
-f"For '{self.cls_name}', haven't supported cache enable for not ps mode.")
 self.embedding_table.unique = self.forward_unique
 self.max_norm = max_norm
 if self.max_norm is not None:
@@ -489,149 +464,9 @@ class EmbeddingLookup(Cell):
 self.max_norm, 'max_norm', self.cls_name)
 self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
-def _process_vocab_cache(self, slice_mode):
-"""PS embeddingLookup cache check and process."""
-self.cache_enable = False
-if self.vocab_cache_size > 0:
-if self.target == 'CPU':
-logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
-"current target is CPU, so it will be ignored.")
-return
-enable_ps = _get_ps_context("enable_ps")
-if not enable_ps:
-logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server training "
-"mode, current mode is not parameter server trainning mode, so it will be ignored.")
-return
-self.is_ps_server = _is_role_pserver() and _enable_distributed_mindrt()
-parallel_mode = _get_parallel_mode()
-is_auto_parallel = parallel_mode in (
-ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-if is_auto_parallel:
-rank_size = get_group_size()
-rank_id = get_rank()
-full_batch = _get_full_batch()
-if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
-raise ValueError(f"For '{self.cls_name}', the cache of parameter server parallel should only be "
-f"used in \"full_batch\" and the value of \"full_batch\" must be True. "
-f"Meanwhile, the value of 'slice_mode' must be \"table_row_slice\"."
-f"But got full_batch: {full_batch} and 'slice_mode': \"{slice_mode}\".")
-self.vocab_cache_size = self.vocab_cache_size * rank_size
-_set_rank_id(rank_id)
-
-self.cache_enable = True
-_set_cache_enable(True)
-
-if _is_role_worker():
-self.vocab_size = self.vocab_cache_size
-
-def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size, param_init,
-dtype=mstype.float32):
-"""PS embeddingLookup cache enable set."""
-if self.sparse:
-self.forward_unique = True
-param_key = _get_unique_parameter_key()
-if _is_role_worker():
-self.embedding_table.is_param_ps = True
-self.embedding_table.cache_enable = True
-self.embedding_table.key = param_key
-_insert_hash_table_size(
-self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size, param_key)
-
-if _enable_distributed_mindrt():
-self.rank_id = get_rank()
-if self.is_ps_server:
-self._slice_pserver_embeddings("zeros", dtype=dtype)
-self._set_cache_enable_and_key_for_pserver(param_key)
-
-def _slice_pserver_embeddings(self, param_init, dtype=mstype.float32):
-'''
-Method to slice embedding tables on Parameter Servers.
-It helps to train with a large scale embedding table and is used only in Parameter Server training mode.
-So EmbeddingLookup op is on CPU device.
-'''
-self.embedding_lookup_list = []
-# The dimension of each embedding table on servers could be different according to the slicing algorithm.
-self.embedding_table_vocab_dim_list = []
-self.embedding_table_list = []
-# For different servers, the offset of their embedding table should be different.
-self.embedding_offset = []
-
-server_num = _get_ps_context("server_num")
-if server_num == 0:
-raise ValueError("The Parameter Server number is zero.")
-# Assign the embedding table dimensions.
-for _ in range(server_num):
-self.embedding_table_vocab_dim_list.append(
-self.vocab_size // server_num)
-rest_vocab_size = self.vocab_size % server_num
-if rest_vocab_size != 0:
-for i in range(rest_vocab_size):
-self.embedding_table_vocab_dim_list[i] += 1
-
-offset = 0
-for i in range(server_num):
-self.embedding_table_list.append(Parameter(initializer(param_init,
-[self.embedding_table_vocab_dim_list[i],
-self.embedding_size], dtype=dtype),
-name="embedding_table_server_" + str(i)))
-
-self.embedding_offset.append(offset)
-offset += self.embedding_table_vocab_dim_list[i]
-
-# Add EmbeddingLookup ops on different servers.
-if self.target == 'CPU':
-embedding_lookup = ops.EmbeddingLookup().set_device('CPU')
-else:
-if self.sparse:
-embedding_lookup = ops.SparseGatherV2()
-else:
-embedding_lookup = ops.Gather()
-embedding_lookup.add_prim_attr(
-'offset', self.embedding_offset[i])
-embedding_lookup.add_prim_attr('rank_id', i)
-embedding_lookup.add_prim_attr('ms_role', 'MS_PSERVER')
-self.embedding_lookup_list.append(embedding_lookup)
-
-# For now unique operation is not applied,
-# so we need to reduce the lookup results from different servers with AddN.
-self.reduce_lookup_result = ops.AddN()
-
-def _do_server_embedding_lookup(self, indices):
-'''
-Construct backbone for EmbeddingLookup operators on servers.
-'''
-result_from_servers = []
-for i in range(_get_ps_context("server_num")):
-result = self.embedding_lookup_list[i](self.embedding_table_list[i],
-indices, self.embedding_offset[i])
-result_from_servers.append(result)
-final_result = self.reduce_lookup_result(result_from_servers)
-return final_result
-
-def _set_cache_enable_and_key_for_pserver(self, param_key):
-'''
-Set cache enable and parameter key for embedding table on parameter servers.
-'''
-# Parameter The Embedding Table on the Server side will be divided according to the number of servers.
-# The divided Embedding Table will be used instead of the complete Embedding Table.
-self.embedding_table = self.embedding_table_list[self.rank_id]
-self.embedding_table.cache_enable = True
-self.embedding_table.key = param_key
-
-def _pserver_embedding_lookup(self, indices):
-'''
-Construct backbone for EmbeddingLookup operators on servers for embedding cache lookup.
-'''
-if self.target == 'CPU':
-return self.embedding_lookup_list[self.rank_id](self.embedding_table, indices,
-self.embedding_offset[self.rank_id])
-return self.embedding_lookup_list[self.rank_id](self.embedding_table, indices, 0)
-
 def construct(self, indices):
 if self.target == "CPU":
 out = self.embeddinglookup(self.embedding_table, indices, 0)
-elif self.is_ps_server:
-out = self._pserver_embedding_lookup(indices)
 else:
 if self.forward_unique:
 shp = self.shape(indices) + (self.embedding_size,)
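The hunks above leave EmbeddingLookup with a plain constructor: vocab_cache_size and the whole parameter-server cache path are gone, and dtype now closes the signature. A minimal sketch of calling the trimmed-down layer; only the argument names come from the signature in the hunk, while the sizes, the input data, and the CPU target are illustrative:

import mindspore as ms
import mindspore.nn as nn

# Illustrative sizes; vocab_cache_size is no longer accepted in 2.7.1.
lookup = nn.EmbeddingLookup(vocab_size=20000, embedding_size=128,
                            param_init='normal', target='CPU',
                            sparse=True, dtype=ms.float32)
indices = ms.Tensor([[1, 5], [3, 7]], ms.int32)
out = lookup(indices)  # expected shape: (2, 2, 128)

Callers that previously passed vocab_cache_size will now get a Python TypeError and need to drop that argument.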
mindspore/nn/layer/normalization.py
CHANGED
@@ -21,6 +21,7 @@ import numbers
 import hashlib
 import numpy as np
 import mindspore.ops as ops
+from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer, Initializer
@@ -917,7 +918,7 @@ class _InstanceNorm(Cell):
 
 self.shape = ops.Shape()
 self.momentum = momentum
-self.instance_bn =
+self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
 
 def construct(self, x):
 self._check_input_dim(self.shape(x), self.cls_name)
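The second hunk has _InstanceNorm.__init__ build the InstanceNorm primitive through the newly imported P alias. A small sketch of that construction in isolation, with placeholder hyper-parameters standing in for the layer's own attributes:

from mindspore.ops import operations as P

# Placeholder values; inside _InstanceNorm they come from self.eps and
# self.momentum, as the added line above shows.
instance_bn = P.InstanceNorm(epsilon=1e-5, momentum=0.1)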
mindspore/nn/layer/thor_layer.py
CHANGED
@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 import mindspore.log as logger
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer, Initializer
-from mindspore.communication.management import get_group_size
+from mindspore.communication.management import get_group_size
 from mindspore.ops.operations._thor_ops import ThorIm2Col
 from mindspore.common.parameter import Parameter
 from mindspore import _checkparam as Validator
@@ -30,8 +30,6 @@ from mindspore._checkparam import twice
 from mindspore import context
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
-from mindspore.parallel._ps_context import _is_role_worker, _get_ps_context, \
-_set_rank_id, _insert_hash_table_size, _set_cache_enable
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
 from mindspore.context import ParallelMode
 from mindspore.nn.layer.basic import ClipByNorm
@@ -695,10 +693,6 @@ class EmbeddingLookupThor(Cell):
 Default: ``None`` .
 sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be ``true`` .
 Default: ``True`` .
-vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: ``0`` . It is valid only in
-'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
-In addition, it should be noted that it will cost the 'DEVICE' memory, so suggests setting a reasonable
-value to avoid insufficient memory.
 
 Inputs:
 - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
@@ -712,10 +706,9 @@ class EmbeddingLookupThor(Cell):
 'table_row_slice' or 'table_column_slice'.
 ValueError: If `sparse` is False and `target` is 'CPU'.
 ValueError: If `slice_mode` is 'field_slice' and `manual_shapes` is None.
-TypeError: If `vocab_size` or `embedding_size`
+TypeError: If `vocab_size` or `embedding_size` is not an int.
 TypeError: If `sparse` is not a bool or `manual_shapes` is not a tuple.
 ValueError: If `vocab_size` or `embedding_size` is less than 1.
-ValueError: If `vocab_cache_size` is less than 0.
 
 
 Supported Platforms:
@@ -736,14 +729,12 @@ class EmbeddingLookupThor(Cell):
 
 def __init__(self, vocab_size, embedding_size, param_init='normal',
 target='CPU', slice_mode='batch_slice', manual_shapes=None,
-max_norm=None, sparse=True
+max_norm=None, sparse=True):
 super(EmbeddingLookupThor, self).__init__()
 Validator.check_value_type('sparse', sparse, [bool], self.cls_name)
 self.vocab_size = Validator.check_positive_int(vocab_size, 'vocab_size', self.cls_name)
-self.vocab_cache_size = Validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size', self.cls_name)
 self.target = target
 self.sparse = sparse
-self.cache_enable = self.vocab_cache_size > 0
 self.forward_unique = False
 self.dtype = mstype.float16
 if target not in ('CPU', 'DEVICE'):
@@ -757,9 +748,6 @@ class EmbeddingLookupThor(Cell):
 else:
 self.gatherv2 = ops.Gather()
 self.embeddinglookup = ops.EmbeddingLookup().set_device('CPU')
-enable_ps = _get_ps_context("enable_ps")
-if enable_ps:
-self._process_vocab_cache(slice_mode)
 self.embedding_size = Validator.check_positive_int(embedding_size, 'embedding_size', self.cls_name)
 self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size],
 mstype.float16), name='embedding_table')
@@ -772,10 +760,6 @@ class EmbeddingLookupThor(Cell):
 self.shape = ops.Shape()
 if is_auto_parallel:
 self.unique = ops.Unique().shard(((1,),))
-if self.cache_enable and enable_ps:
-self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
-if is_auto_parallel:
-self.unique.add_prim_attr('cache_enable', True)
 indices_shape_size = 2
 if slice_mode == "field_slice" and is_auto_parallel:
 if not manual_shapes:
@@ -792,7 +776,7 @@ class EmbeddingLookupThor(Cell):
 self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
 elif slice_mode == "table_row_slice" and is_auto_parallel:
 full_batch = _get_full_batch()
-if (target == 'DEVICE' and not full_batch)
+if (target == 'DEVICE' and not full_batch):
 indices_shape_size = 1
 self.gather_revert.shard(((1, 1), (get_group_size(),)))
 self.forward_unique = True
@@ -818,11 +802,6 @@ class EmbeddingLookupThor(Cell):
 raise ValueError(f"For '{self.cls_name}', the 'slice_mode' must be one of values in "
 f"['field_slice', 'table_row_slice', 'table_column_slice', 'batch_slice'], "
 f"but got 'slice_mode': {slice_mode}")
-if self.cache_enable and not enable_ps:
-if parallel_mode != ParallelMode.STAND_ALONE:
-raise ValueError(f"For '{self.cls_name}', the 'parallel_mode' must be equal to "
-f"'ParallelMode.STAND_ALONE', but got {parallel_mode}.")
-self._set_cache_enable()
 self.embedding_table.unique = self.forward_unique
 self.max_norm = max_norm
 if self.max_norm is not None:
@@ -859,66 +838,6 @@ class EmbeddingLookupThor(Cell):
 self.matrix_g = matrix_g
 return out
 
-def _set_cache_enable(self):
-"""EmbeddingLookup cache check for not ps env, which is only support 'ascend'."""
-if self.target != 'DEVICE':
-raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
-f"only when 'target' is 'DEVICE', but got 'target': {self.target}.")
-if not self.sparse:
-raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
-f"only when 'sparse' is true, but got 'sparse': {self.sparse}.")
-if context.get_context("device_target") != 'Ascend':
-raise ValueError(f"For '{self.cls_name}', the configuration of 'vocab_cache_size' is valid "
-f"only when 'device_target' is 'Ascend', but got {context.get_context('device_target')}.")
-
-logger.info("EmbeddingLookup cache enable takes effect.")
-self.forward_unique = True
-self.unique = ops.Unique().set_device('CPU')
-self.unique.add_prim_attr('cache_enable', True)
-self.embedding_table.cache_enable = self.cache_enable
-self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
-self.reshape_first = ops.Reshape().set_device('CPU')
-
-def _process_vocab_cache(self, slice_mode):
-"""PS embeddingLookup cache check and process."""
-self.cache_enable = False
-if self.vocab_cache_size > 0:
-if self.target == 'CPU':
-logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
-"current target is CPU, so it will be ignored.")
-return
-enable_ps = _get_ps_context("enable_ps")
-if not enable_ps:
-logger.warning(
-"The configuration of 'vocab_cache_size' is valid only in parameter server trainning "
-"mode, current mode is not parameter server trainning mode, so it will be ignored.")
-return
-parallel_mode = _get_parallel_mode()
-is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-if is_auto_parallel:
-rank_size = get_group_size()
-rank_id = get_rank()
-full_batch = _get_full_batch()
-if rank_size > 1 and not (full_batch and slice_mode == "table_row_slice"):
-raise ValueError(f"For '{self.cls_name}', the embeddingLookup cache of parameter server parallel "
-f"only be used in 'full_batch' and 'table_row_slice' parallel strategy, but got "
-f"'full_batch': {full_batch}, 'slice_mode': {slice_mode}.")
-self.vocab_cache_size = self.vocab_cache_size * rank_size
-_set_rank_id(rank_id)
-self.cache_enable = True
-if _is_role_worker():
-self.vocab_size = self.vocab_cache_size
-
-def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
-"""PS embeddingLookup cache enable set."""
-self.embedding_table.cache_enable = True
-self.embedding_table.is_param_ps = True
-_set_cache_enable(True)
-if self.sparse:
-self.forward_unique = True
-if _is_role_worker():
-_insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size)
-
 def construct(self, indices):
 if self.target == "CPU":
 out = self.embeddinglookup(self.embedding_table, indices, 0)
mindspore/nn/optim/ada_grad.py
CHANGED
@@ -204,7 +204,6 @@ class Adagrad(Optimizer):
 def construct(self, grads):
 params = self._parameters
 accum = self.accum
-grads = self.flatten_gradients(grads)
 grads = self.decay_weight(grads)
 grads = self.gradients_centralization(grads)
 grads = self.scale_grad(grads)
mindspore/nn/optim/adafactor.py
CHANGED
@@ -408,7 +408,6 @@ class AdaFactor(Optimizer):
 
 @jit(backend="ms_backend")
 def construct(self, gradients):
-gradients = self.flatten_gradients(gradients)
 lr = self.get_lr()
 self.assignadd(self.global_step, self.global_step_increase_tensor)
 step = F.assign_add(self.step, 1)