mindspore 2.3.0rc1-cp37-cp37m-manylinux1_x86_64.whl → 2.3.0rc2-cp37-cp37m-manylinux1_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +1 -1
- mindspore/_akg/akg/utils/tbe_codegen_utils.py +13 -3
- mindspore/_c_dataengine.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_c_expression.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/_checkparam.py +20 -0
- mindspore/_extends/parse/parser.py +1 -1
- mindspore/_extends/parse/standard_method.py +6 -5
- mindspore/_mindspore_offline_debug.cpython-37m-x86_64-linux-gnu.so +0 -0
- mindspore/amp.py +5 -5
- mindspore/bin/cache_admin +0 -0
- mindspore/bin/cache_server +0 -0
- mindspore/boost/boost_cell_wrapper.py +1 -1
- mindspore/boost/group_loss_scale_manager.py +1 -1
- mindspore/common/__init__.py +4 -2
- mindspore/common/_register_for_recompute.py +48 -0
- mindspore/common/_stub_tensor.py +1 -0
- mindspore/common/api.py +56 -4
- mindspore/common/dtype.py +5 -3
- mindspore/common/dump.py +2 -2
- mindspore/common/hook_handle.py +51 -4
- mindspore/common/initializer.py +1 -1
- mindspore/common/jit_config.py +17 -6
- mindspore/common/parameter.py +7 -2
- mindspore/common/recompute.py +247 -0
- mindspore/common/sparse_tensor.py +2 -2
- mindspore/common/symbol.py +1 -1
- mindspore/common/tensor.py +74 -36
- mindspore/communication/__init__.py +3 -3
- mindspore/communication/management.py +30 -30
- mindspore/context.py +28 -15
- mindspore/dataset/__init__.py +5 -5
- mindspore/dataset/audio/__init__.py +2 -2
- mindspore/dataset/audio/transforms.py +51 -51
- mindspore/dataset/callback/ds_callback.py +2 -2
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +3 -3
- mindspore/dataset/engine/datasets_audio.py +14 -14
- mindspore/dataset/engine/datasets_standard_format.py +3 -3
- mindspore/dataset/engine/datasets_text.py +38 -38
- mindspore/dataset/engine/datasets_user_defined.py +3 -3
- mindspore/dataset/engine/datasets_vision.py +68 -68
- mindspore/dataset/text/__init__.py +3 -3
- mindspore/dataset/text/transforms.py +26 -26
- mindspore/dataset/transforms/__init__.py +1 -1
- mindspore/dataset/vision/__init__.py +3 -3
- mindspore/dataset/vision/transforms.py +92 -92
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/experimental/optim/adadelta.py +2 -2
- mindspore/experimental/optim/adagrad.py +2 -2
- mindspore/experimental/optim/adam.py +2 -2
- mindspore/experimental/optim/adamax.py +2 -2
- mindspore/experimental/optim/adamw.py +2 -2
- mindspore/experimental/optim/asgd.py +2 -2
- mindspore/experimental/optim/lr_scheduler.py +24 -20
- mindspore/experimental/optim/nadam.py +2 -2
- mindspore/experimental/optim/optimizer.py +1 -1
- mindspore/experimental/optim/radam.py +2 -2
- mindspore/experimental/optim/rmsprop.py +2 -2
- mindspore/experimental/optim/rprop.py +2 -2
- mindspore/experimental/optim/sgd.py +2 -2
- mindspore/hal/stream.py +2 -0
- mindspore/include/mindapi/base/types.h +5 -0
- mindspore/lib/libdnnl.so.2 +0 -0
- mindspore/lib/libmindspore.so +0 -0
- mindspore/lib/libmindspore_backend.so +0 -0
- mindspore/lib/libmindspore_common.so +0 -0
- mindspore/lib/libmindspore_core.so +0 -0
- mindspore/lib/libmindspore_glog.so.0 +0 -0
- mindspore/lib/libmindspore_gpr.so.15 +0 -0
- mindspore/lib/libmindspore_grpc++.so.1 +0 -0
- mindspore/lib/libmindspore_grpc.so.15 +0 -0
- mindspore/lib/libmindspore_shared_lib.so +0 -0
- mindspore/lib/libopencv_core.so.4.5 +0 -0
- mindspore/lib/libopencv_imgcodecs.so.4.5 +0 -0
- mindspore/lib/libopencv_imgproc.so.4.5 +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/aicpu_kernel/impl/libcust_cpu_kernels.so +0 -0
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_impl/cpu/config/cust_aicpu_kernel.json +6 -6
- mindspore/lib/plugin/ascend/custom_aicpu_ops/op_proto/libcust_op_proto.so +0 -0
- mindspore/lib/plugin/ascend/libdvpp_utils.so +0 -0
- mindspore/lib/plugin/ascend/libmindspore_cpu_kernels.so +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.10 +0 -0
- mindspore/lib/plugin/gpu/libcuda_ops.so.11 +0 -0
- mindspore/lib/plugin/gpu10.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.1/libnccl.so.2 +0 -0
- mindspore/lib/plugin/gpu11.6/libnccl.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_ascend.so.2 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.10.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.1 +0 -0
- mindspore/lib/plugin/libmindspore_gpu.so.11.6 +0 -0
- mindspore/log.py +2 -2
- mindspore/mint/__init__.py +457 -0
- mindspore/mint/nn/__init__.py +430 -0
- mindspore/mint/nn/functional.py +424 -0
- mindspore/mint/optim/__init__.py +24 -0
- mindspore/mint/optim/adamw.py +186 -0
- mindspore/multiprocessing/__init__.py +4 -0
- mindspore/nn/__init__.py +3 -0
- mindspore/nn/cell.py +51 -47
- mindspore/nn/extend/__init__.py +29 -0
- mindspore/nn/extend/basic.py +140 -0
- mindspore/nn/extend/embedding.py +143 -0
- mindspore/nn/extend/layer/__init__.py +27 -0
- mindspore/nn/extend/layer/normalization.py +107 -0
- mindspore/nn/extend/pooling.py +117 -0
- mindspore/nn/generator.py +297 -0
- mindspore/nn/layer/basic.py +109 -1
- mindspore/nn/layer/container.py +2 -2
- mindspore/nn/layer/conv.py +6 -6
- mindspore/nn/layer/embedding.py +1 -1
- mindspore/nn/layer/normalization.py +21 -43
- mindspore/nn/layer/padding.py +4 -0
- mindspore/nn/optim/ada_grad.py +2 -2
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adafactor.py +1 -1
- mindspore/nn/optim/adam.py +7 -7
- mindspore/nn/optim/adamax.py +2 -2
- mindspore/nn/optim/adasum.py +2 -2
- mindspore/nn/optim/asgd.py +2 -2
- mindspore/nn/optim/ftrl.py +1 -1
- mindspore/nn/optim/lamb.py +3 -3
- mindspore/nn/optim/lars.py +1 -1
- mindspore/nn/optim/lazyadam.py +2 -2
- mindspore/nn/optim/momentum.py +2 -2
- mindspore/nn/optim/optimizer.py +2 -2
- mindspore/nn/optim/proximal_ada_grad.py +2 -2
- mindspore/nn/optim/rmsprop.py +2 -2
- mindspore/nn/optim/rprop.py +2 -2
- mindspore/nn/optim/sgd.py +2 -2
- mindspore/nn/optim/thor.py +2 -2
- mindspore/nn/wrap/cell_wrapper.py +9 -9
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/ops/_grad_experimental/grad_comm_ops.py +4 -2
- mindspore/ops/_vmap/vmap_grad_nn_ops.py +41 -2
- mindspore/ops/_vmap/vmap_math_ops.py +27 -8
- mindspore/ops/_vmap/vmap_nn_ops.py +66 -8
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +73 -1
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +12 -3
- mindspore/ops/auto_generate/gen_arg_handler.py +24 -0
- mindspore/ops/auto_generate/gen_extend_func.py +274 -0
- mindspore/ops/auto_generate/gen_ops_def.py +889 -22
- mindspore/ops/auto_generate/gen_ops_prim.py +3541 -253
- mindspore/ops/auto_generate/pyboost_inner_prim.py +282 -0
- mindspore/ops/composite/multitype_ops/_compile_utils.py +2 -1
- mindspore/ops/composite/multitype_ops/_constexpr_utils.py +9 -0
- mindspore/ops/extend/__init__.py +9 -1
- mindspore/ops/extend/array_func.py +134 -27
- mindspore/ops/extend/math_func.py +3 -3
- mindspore/ops/extend/nn_func.py +363 -2
- mindspore/ops/function/__init__.py +19 -2
- mindspore/ops/function/array_func.py +463 -439
- mindspore/ops/function/clip_func.py +7 -18
- mindspore/ops/function/grad/grad_func.py +5 -5
- mindspore/ops/function/linalg_func.py +4 -4
- mindspore/ops/function/math_func.py +260 -243
- mindspore/ops/function/nn_func.py +825 -62
- mindspore/ops/function/random_func.py +73 -4
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +1 -1
- mindspore/ops/functional.py +2 -2
- mindspore/ops/op_info_register.py +1 -31
- mindspore/ops/operations/__init__.py +2 -3
- mindspore/ops/operations/_grad_ops.py +2 -107
- mindspore/ops/operations/_inner_ops.py +5 -5
- mindspore/ops/operations/_sequence_ops.py +2 -2
- mindspore/ops/operations/array_ops.py +11 -233
- mindspore/ops/operations/comm_ops.py +32 -32
- mindspore/ops/operations/custom_ops.py +7 -89
- mindspore/ops/operations/manually_defined/ops_def.py +329 -4
- mindspore/ops/operations/math_ops.py +13 -163
- mindspore/ops/operations/nn_ops.py +9 -316
- mindspore/ops/operations/random_ops.py +1 -1
- mindspore/ops/operations/sparse_ops.py +3 -3
- mindspore/ops/primitive.py +2 -2
- mindspore/ops_generate/arg_dtype_cast.py +12 -3
- mindspore/ops_generate/arg_handler.py +24 -0
- mindspore/ops_generate/gen_ops_inner_prim.py +2 -0
- mindspore/ops_generate/gen_pyboost_func.py +13 -6
- mindspore/ops_generate/pyboost_utils.py +2 -17
- mindspore/parallel/__init__.py +3 -2
- mindspore/parallel/_auto_parallel_context.py +106 -1
- mindspore/parallel/_parallel_serialization.py +34 -2
- mindspore/parallel/_utils.py +16 -0
- mindspore/parallel/algo_parameter_config.py +4 -4
- mindspore/parallel/checkpoint_transform.py +249 -77
- mindspore/parallel/cluster/process_entity/_api.py +1 -1
- mindspore/parallel/parameter_broadcast.py +1 -1
- mindspore/parallel/shard.py +1 -1
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +1 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +17 -5
- mindspore/profiler/parser/ascend_msprof_exporter.py +3 -3
- mindspore/profiler/parser/ascend_msprof_generator.py +10 -3
- mindspore/profiler/parser/ascend_op_generator.py +26 -9
- mindspore/profiler/parser/ascend_timeline_generator.py +7 -4
- mindspore/profiler/parser/profiler_info.py +11 -1
- mindspore/profiler/profiling.py +13 -5
- mindspore/rewrite/api/node.py +12 -12
- mindspore/rewrite/api/symbol_tree.py +11 -11
- mindspore/run_check/_check_version.py +1 -1
- mindspore/safeguard/rewrite_obfuscation.py +2 -2
- mindspore/train/amp.py +4 -4
- mindspore/train/anf_ir_pb2.py +8 -2
- mindspore/train/callback/_backup_and_restore.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +2 -2
- mindspore/train/callback/_early_stop.py +2 -2
- mindspore/train/callback/_landscape.py +4 -4
- mindspore/train/callback/_loss_monitor.py +2 -2
- mindspore/train/callback/_on_request_exit.py +2 -2
- mindspore/train/callback/_reduce_lr_on_plateau.py +2 -2
- mindspore/train/callback/_summary_collector.py +2 -2
- mindspore/train/callback/_time_monitor.py +2 -2
- mindspore/train/dataset_helper.py +8 -3
- mindspore/train/loss_scale_manager.py +2 -2
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/mind_ir_pb2.py +22 -17
- mindspore/train/model.py +15 -15
- mindspore/train/serialization.py +18 -18
- mindspore/train/summary/summary_record.py +7 -7
- mindspore/train/train_thor/convert_utils.py +3 -3
- mindspore/version.py +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/METADATA +1 -1
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/RECORD +226 -212
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0rc1.dist-info → mindspore-2.3.0rc2.dist-info}/top_level.txt +0 -0
mindspore/nn/cell.py
CHANGED
@@ -20,10 +20,9 @@ import inspect
 import os
 import time
 from collections import OrderedDict
-from types import FunctionType, MethodType
 import numpy
 
-from mindspore._checkparam import args_type_check
+from mindspore._checkparam import args_type_check, check_hook_fn
 from mindspore.common._auto_dynamic import is_auto_dynamic, convert_inputs_to_dynamic
 from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
@@ -34,7 +33,7 @@ from mindspore._c_expression import init_pipeline, update_func_graph_hyper_param
 from mindspore import _checkparam as Validator
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _cell_graph_executor, _pynative_executor, _get_args_for_run, cells_compile_cache
-from mindspore.common.api import _generate_branch_control_input
+from mindspore.common.api import _generate_branch_control_input, _convert_python_data, _get_args_for_run_predict
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops.operations import Cast
@@ -43,6 +42,7 @@ from mindspore.ops.operations import _inner_ops as inner
 from mindspore.parallel.shard import Shard
 from mindspore._check_jit_forbidden_api import jit_forbidden_register
 from mindspore.common._decorator import deprecated
+from mindspore.common._register_for_recompute import recompute_registry
 
 
 class Cell(Cell_):
@@ -125,11 +125,13 @@ class Cell(Cell_):
         self._create_time = int(time.time() * 1e9)
         self.arguments_key = ""
         self.compile_cache = set()
+        self.phase_cache = dict()
         cells_compile_cache[id(self)] = self.compile_cache
         self.parameter_broadcast_done = False
         self._id = 1
         self.exist_names = set("")
         self.exist_objs = set()
+        self.recompute_cell = None
         init_pipeline()
 
         # call gc to release GE session resources used by non-used cell objects
@@ -217,7 +219,7 @@
 
         Tutorial Examples:
             - `Cell and Parameter - Custom Cell Reverse
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/advanced/modules/layer.html#custom-cell-reverse>`_
         """
         return self._bprop_debug
 
@@ -415,7 +417,7 @@
             elif isinstance(item, float):
                 res.append(self.cast(item, dst_type))
             elif hasattr(item, "dtype") and item.dtype in \
-
+                    {mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16} and item.dtype != dst_type:
                 res.append(self.cast(item, dst_type))
             else:
                 res.append(item)
@@ -474,7 +476,10 @@
         elif hasattr(self, "_shard_fn"):
             output = self._shard_fn(*cast_inputs, **kwargs)
         else:
-
+            if self.recompute_cell is not None:
+                output = self.recompute_cell(*cast_inputs, **kwargs)
+            else:
+                output = self.construct(*cast_inputs, **kwargs)
         if self._enable_forward_hook:
             output = self._run_forward_hook(cast_inputs, output)
         return output
@@ -659,6 +664,16 @@
         self.check_names_and_refresh_name()
         self._is_check_and_refresh = True
 
+    def _predict(self, *args, **kwargs):
+        if not hasattr(self, "phase"):
+            return False, None
+        if (self.phase == "prefill" or self.phase == 'increment') and self.phase in self.phase_cache:
+            new_args = _get_args_for_run_predict(self, args, kwargs, self._compile_args)
+            res = _cell_graph_executor._graph_executor(tuple(new_args), self.phase_cache[self.phase])
+            res = _convert_python_data(res)
+            return True, res
+        return False, None
+
     def __call__(self, *args, **kwargs):
         # Run in Graph mode.
         if os.getenv("MS_JIT") != '0' and context._get_mode() == context.GRAPH_MODE:
@@ -667,7 +682,12 @@
             bound_arguments.apply_defaults()
             args = bound_arguments.args
             kwargs = bound_arguments.kwargs
+
+            predict_compiled, res = self._predict(*args, **kwargs)
+            if predict_compiled:
+                return res
             self._check_construct_args(*args)
+
             if self._hook_fn_registered():
                 logger.warning(f"For 'Cell', it's not support hook function in graph mode. If you want to use hook "
                                f"function, please use context.set_context to set pynative mode.")
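The `_predict` fast path keeps a compiled graph per phase ("prefill" / "increment") in `phase_cache` and hands it to the graph executor directly, skipping argument re-checking and recompilation on later calls. A rough, self-contained sketch of that caching pattern (illustrative names only, not MindSpore's internal API):

class _PhaseCachedRunner:
    """Illustrative stand-in for the phase-cache fast path used by Cell._predict."""

    def __init__(self, compile_fn, run_fn):
        self.compile_fn = compile_fn   # (phase, args) -> graph key; expensive, done once per phase
        self.run_fn = run_fn           # (args, graph_key) -> outputs; cheap
        self.phase_cache = {}          # phase name -> compiled graph key

    def __call__(self, phase, *args):
        if phase in self.phase_cache:  # fast path: reuse the graph compiled for this phase
            return self.run_fn(args, self.phase_cache[phase])
        self.phase_cache[phase] = self.compile_fn(phase, args)
        return self.run_fn(args, self.phase_cache[phase])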
@@ -964,7 +984,6 @@
             return self._dynamic_shape_inputs
         return args
 
-
     def compile(self, *args, **kwargs):
         """
         Compile Cell as a computation graph, the input must be consistent with the input defined in construct.
@@ -1335,7 +1354,7 @@
 
         Tutorial Examples:
             - `Model Training - Optimizer
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#optimizer>`_
         """
         return list(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))
 
@@ -1446,7 +1465,7 @@
 
         Tutorial Examples:
            - `Building a Network - Model Parameters
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/model.html#model-parameters>`_
         """
         cells = []
         if expand:
@@ -1785,7 +1804,7 @@
         accelerate the algorithm in the algorithm library.
 
         If `boost_type` is not in the algorithm library, please view the algorithm in the algorithm library through
-        `algorithm library <https://gitee.com/mindspore/mindspore/tree/
+        `algorithm library <https://gitee.com/mindspore/mindspore/tree/master/mindspore/python/mindspore/boost>`_.
 
         Note:
             Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
@@ -1842,7 +1861,7 @@
 
         Tutorial Examples:
            - `Model Training - Implementing Training and Evaluation
-              <https://mindspore.cn/tutorials/en/
+              <https://mindspore.cn/tutorials/en/master/beginner/train.html#training-and-evaluation>`_
         """
         if mode:
             self._phase = 'train'
@@ -1936,8 +1955,8 @@
             hook_fn (function): Python function. Forward pre hook function.
 
         Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
         Raises:
             TypeError: If the `hook_fn` is not a function of python.
@@ -1972,17 +1991,8 @@
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
-        if
-            logger.warning(f"'register_forward_pre_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_forward_pre_hook", hook_fn):
             return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_forward_pre_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
         self._enable_forward_pre_hook = True
         _pynative_executor.set_hook_changed(self)
         if not hasattr(self, '_forward_pre_hook_key'):
@@ -2036,8 +2046,8 @@
             hook_fn (function): Python function. Forward hook function.
 
         Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
         Raises:
             TypeError: If the `hook_fn` is not a function of python.
@@ -2074,17 +2084,8 @@
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]), Tensor(shape=[1], dtype=Float32,
         value= [ 2.00000000e+00]))
         """
-        if
-            logger.warning(f"'register_forward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_forward_hook", hook_fn):
             return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_forward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
-        if hook_fn.__code__.co_name == "staging_specialize":
-            raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
-
         self._enable_forward_hook = True
         _pynative_executor.set_hook_changed(self)
         if not hasattr(self, '_forward_hook_key'):
@@ -2136,8 +2137,8 @@
             hook_fn (function): Python function. Backward hook function.
 
         Returns:
-
-
+            A handle corresponding to the `hook_fn` . The handle can be used to remove the added `hook_fn` by calling
+            `handle.remove()` .
 
         Raises:
             TypeError: If the `hook_fn` is not a function of python.
@@ -2172,14 +2173,8 @@
         >>> print(output)
         (Tensor(shape=[1], dtype=Float32, value= [ 2.00000000e+00]),)
         """
-        if
-            logger.warning(f"'register_backward_hook' function is only supported in pynative mode, you can use "
-                           f"context.set_context to set pynative mode.")
+        if not check_hook_fn("register_backward_hook", hook_fn):
             return HookHandle()
-
-        if not isinstance(hook_fn, (FunctionType, MethodType)):
-            raise TypeError(f"When using 'register_backward_hook(hook_fn)', the type of 'hook_fn' must be python "
-                            f"function, but got {type(hook_fn)}.")
         if self._cell_backward_hook is None:
            self._enable_backward_hook = True
            self._cell_backward_hook = inner.CellBackwardHook(self.cls_name + "(" + str(id(self)) + ")")
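The three hook-registration methods above now delegate validation to a single helper. Judging from the inline checks it replaces, `check_hook_fn` (imported from `mindspore._checkparam` in the first hunk) plausibly behaves like the sketch below; this is inferred from the removed code, not the actual implementation:

from types import FunctionType, MethodType

def check_hook_fn(hook_type, hook_fn):
    """Return True when hook_fn may be registered; raise on invalid input (sketch)."""
    # Only plain Python functions or bound methods are accepted.
    if not isinstance(hook_fn, (FunctionType, MethodType)):
        raise TypeError(f"When using '{hook_type}(hook_fn)', the type of 'hook_fn' must be python "
                        f"function, but got {type(hook_fn)}.")
    # Functions compiled with @jit report co_name 'staging_specialize' and are rejected.
    if hook_fn.__code__.co_name == "staging_specialize":
        raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
    return True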
@@ -2209,10 +2204,16 @@
         else:
             inputs = self._cell_backward_hook(*inputs)
             inputs = (inputs,)
-        if
-
+        if self.recompute_cell is not None:
+            if isinstance(inputs, tuple):
+                outputs = self.recompute_cell(*inputs, **kwargs)
+            else:
+                outputs = self.recompute_cell(inputs, **kwargs)
         else:
-
+            if isinstance(inputs, tuple):
+                outputs = self.construct(*inputs, **kwargs)
+            else:
+                outputs = self.construct(inputs, **kwargs)
         outputs = self._cell_backward_hook(outputs)
         return outputs
 
@@ -2342,6 +2343,9 @@
             introduced by optimizer shard are recomputed in auto parallel or semi auto parallel mode.
             Default: ``False`` .
         """
+        if context.get_context("mode") == context.PYNATIVE_MODE:
+            self.recompute_cell = recompute_registry.get()(self.construct)
+            return
         self._recompute()
         if 'mp_comm_recompute' in kwargs.keys():
             self._mp_comm_recompute(kwargs.get('mp_comm_recompute', False))
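With the new PYNATIVE_MODE branch, `Cell.recompute()` wraps `construct` through the recompute registry instead of only taking effect in graph mode, so activation recomputation also works in PyNative. A minimal usage sketch (the toy network is illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

class Block(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 16)
        self.relu = nn.ReLU()

    def construct(self, x):
        return self.relu(self.dense(x))

block = Block()
block.recompute()                     # activations of this cell are recomputed in the backward pass
x = Tensor(np.random.randn(4, 16).astype(np.float32))
grad_fn = ms.grad(block)              # gradient w.r.t. the input
print(grad_fn(x).shape)               # (4, 16)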
mindspore/nn/extend/__init__.py
ADDED
@@ -0,0 +1,29 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+nn Extend.
+"""
+from __future__ import absolute_import
+
+from mindspore.nn.extend.embedding import Embedding
+from mindspore.nn.extend.basic import Linear
+from mindspore.nn.extend.pooling import MaxPool2d
+from mindspore.nn.extend import layer
+from mindspore.nn.extend.layer import *
+
+__all__ = ['Embedding', 'Linear', 'MaxPool2d']
+__all__.extend(layer.__all__)
+
+__all__.sort()
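With these re-exports the new layers are reachable from a single namespace, plus whatever the `layer` subpackage (currently the normalization module) exports:

from mindspore.nn.extend import Embedding, Linear, MaxPool2d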
mindspore/nn/extend/basic.py
ADDED
@@ -0,0 +1,140 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""basic"""
+from __future__ import absolute_import
+
+import math
+
+import mindspore.common.dtype as mstype
+from mindspore import _checkparam as Validator
+from mindspore._extends import cell_attr_register
+from mindspore.common.initializer import initializer, HeUniform, Uniform
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+from mindspore.nn.cell import Cell
+from mindspore.ops import operations as P
+
+__all__ = ['Linear']
+
+
+class Linear(Cell):
+    r"""
+    The linear connected layer.
+
+    Applies linear connected layer for the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = X * kernel + bias
+
+    where :math:`X` is the input tensors, :math:`\text{kernel}` is a weight matrix with the same
+    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+    with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+    Args:
+        in_features (int): The number of features in the input space.
+        out_features (int): The number of features in the output space.
+        bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            weight will be initialized using HeUniform.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            bias will be initialized using Uniform.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None`` .
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
+          to :math:`in\_features` in `Inputs`.
+
+    Outputs:
+        Tensor of shape :math:`(*, out\_features)`.
+
+    Raises:
+        TypeError: If `in_features` or `out_features` is not an int.
+        TypeError: If `bias` is not a bool.
+        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+                    is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
+        ValueError: If length of shape of `bias_init` is not equal to 1
+                    or shape[0] of `bias_init` is not equal to `out_features`.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore.nn.extend import Linear
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = Linear(3, 4)
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 4)
+    """
+
+    @cell_attr_register(attrs=['has_bias'])
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 weight_init=None,
+                 bias_init=None,
+                 dtype=None):
+        """Initialize Linear."""
+        super(Linear, self).__init__()
+        self.in_features = Validator.check_positive_int(
+            in_features, "in_features", self.cls_name)
+        self.out_features = Validator.check_positive_int(
+            out_features, "out_features", self.cls_name)
+        self.has_bias = Validator.check_bool(
+            bias, "has_bias", self.cls_name)
+        self.dense = P.Dense()
+        if dtype is None:
+            dtype = mstype.float32
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
+                    weight_init.shape[1] != in_features:
+                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+                                 f"be equal to 2, and the first dim must be equal to 'out_features', and the "
+                                 f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
+                                 f"'out_features': {out_features}, 'in_features': {in_features}.")
+        if weight_init is None:
+            weight_init = HeUniform(math.sqrt(5))
+        self.weight = Parameter(initializer(
+            weight_init, [out_features, in_features], dtype=dtype), name="weight")
+
+        self.bias = None
+        if self.has_bias:
+            if isinstance(bias_init, Tensor):
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
+                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+                                     f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
+                                     f"'bias_init': {bias_init}, 'out_features': {out_features}.")
+            if bias_init is None:
+                bound = 1 / math.sqrt(in_features)
+                bias_init = Uniform(scale=bound)
+            self.bias = Parameter(initializer(
+                bias_init, [out_features], dtype=dtype), name="bias")
+
+    def construct(self, x):
+        x = self.dense(x, self.weight, self.bias)
+        return x
+
+    def extend_repr(self):
+        s = f'input_features={self.in_features}, output_features={self.out_features}'
+        if self.has_bias:
+            s += f', has_bias={self.has_bias}'
+        return s
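The `P.Dense` primitive applies `x @ weight.T + bias`, which is why the weight is created with shape `[out_features, in_features]`. A quick equivalence check (a sketch; it assumes the `mindspore.nn.extend.Linear` added in this diff):

import numpy as np
import mindspore as ms
from mindspore.nn.extend import Linear

x = ms.Tensor(np.random.randn(2, 3).astype(np.float32))
net = Linear(3, 4)                                   # weight shape: (4, 3), bias shape: (4,)
y = net(x)
ref = x.asnumpy() @ net.weight.asnumpy().T + net.bias.asnumpy()
print(np.allclose(y.asnumpy(), ref, atol=1e-5))      # expected: True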
mindspore/nn/extend/embedding.py
ADDED
@@ -0,0 +1,143 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""embedding"""
+from __future__ import absolute_import
+
+import mindspore.common.dtype as mstype
+from mindspore.common.initializer import Normal
+from mindspore import _checkparam as Validator
+from mindspore.nn.cell import Cell
+from mindspore import ops
+from mindspore.common.parameter import Parameter
+from mindspore.common.tensor import Tensor
+
+__all__ = ['Embedding']
+
+
+class Embedding(Cell):
+    r"""
+    Embedding layer.
+    Retrieve the word embeddings in weight stored in the layer using indices specified in `input`.
+
+    .. warning::
+        On Ascend, the behavior is unpredictable when the value of `input` is invalid.
+
+    Args:
+        num_embeddings (int): Size of the dictionary of embeddings.
+        embedding_dim (int): The size of each embedding vector.
+        padding_idx (int, optional): If the value is not None, the corresponding row of embedding vector
+            will not be updated in training. The value of embedding vector at `padding_idx` will default
+            to zeros when the Embedding layer is newly constructed. The value should be in range
+            `[-num_embeddings, num_embeddings)` if it's not ``None``. Default ``None``.
+        max_norm (float, optional): If the value is not None, firstly get the p-norm result of the embedding
+            vector specified by `input` where p is specified by `norm_type`; if the result is larger then `max_norm`,
+            update the embedding vector` with :math:`\frac{max\_norm}{result+1e^{-7}}`. Default ``None``.
+        norm_type (float, optional): Indicated the value of p in p-norm. Default ``2.0``.
+        scale_grad_by_freq (bool, optional): If ``True`` the gradients will be scaled by the inverse of frequency
+            of the index in `input`. Default ``False``.
+        _weight (Tensor, optional): Used to initialize the weight of Embedding. If ``None``, the weight will be
+            initialized from normal distribution :math:`{N}(\text{sigma=1.0}, \text{mean=0.0})`. Default ``None``.
+        dtype (mindspore.dtype, optional) : Dtype of Parameters. It is meaningless when `_weight` is not None.
+            Default: ``mindspore.float32``.
+
+    Inputs:
+        - **input** (Tensor) - The indices used to lookup in the embedding vector. The data type must be
+          mindspore.int32 or mindspore.int64, and the value should be in range `[0, num_embeddings)`.
+
+    Outputs:
+        Tensor, has the same data type as weight, the shape is :math:`(*input.shape, embedding_dim)`.
+
+    Raises:
+        TypeError: If `num_embeddings` is not an int.
+        TypeError: If `embedding_dim` is not an int.
+        ValueError: If `padding_idx` is out of valid range.
+        TypeError: If `max_norm` is not a float.
+        TypeError: If `norm_type` is not a float.
+        TypeError: If `scale_grad_by_freq` is not a bool.
+        TypeError: If `dtype` is not one of mindspore.dtype.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> import numpy as np
+        >>> from mindspore import Tensor, nn
+        >>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
+        >>> embedding = nn.extend.Embedding(num_embeddings=10, embedding_dim=3)
+        >>> output = embedding(input)
+        >>> print(output)
+        [[[-0.0024154 -0.01203444 0.00811537]
+          [ 0.00233847 -0.00596091 0.00536799]
+          [-0.0024154 -0.01203444 0.00811537]
+          [-0.0024154 -0.01203444 0.00811537]]
+         [[ 0.00233847 -0.00596091 0.00536799]
+          [ 0.00233847 -0.00596091 0.00536799]
+          [-0.0024154 -0.01203444 0.00811537]
+          [ 0.00233847 -0.00596091 0.00536799]]]
+    """
+
+    def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0,
+                 scale_grad_by_freq=False, _weight=None, dtype=mstype.float32):
+        """Initialize Embedding."""
+        super().__init__()
+        self.num_embeddings = Validator.check_value_type(
+            'num_embeddings', num_embeddings, [int], self.cls_name)
+        self.embedding_dim = Validator.check_value_type(
+            'embedding_dim', embedding_dim, [int], self.cls_name)
+        Validator.check_subclass(
+            "dtype", dtype, mstype.number_type, self.cls_name)
+        self.dtype = dtype
+        self.padding_idx = padding_idx
+        if _weight is None:
+            init_tensor = Tensor(shape=[num_embeddings, embedding_dim], dtype=dtype, init=Normal(1, 0))
+            init_tensor = self._zero_weight_by_index(init_tensor)
+            self.weight = Parameter(init_tensor, name='weight')
+        else:
+            self.weight = Parameter(_weight)
+
+        self.max_norm = max_norm
+        if max_norm is not None:
+            self.max_norm = Validator.check_value_type('max_norm', max_norm, [float], self.cls_name)
+
+        self.norm_type = norm_type
+        if norm_type is not None:
+            self.norm_type = Validator.check_value_type('norm_type', norm_type,
+                                                        [float], self.cls_name)
+
+        self.scale_grad_by_freq = scale_grad_by_freq
+        if scale_grad_by_freq is not None:
+            self.scale_grad_by_freq = Validator.check_value_type('scale_grad_by_freq',
+                                                                 scale_grad_by_freq,
+                                                                 [bool], self.cls_name)
+
+    def _zero_weight_by_index(self, init_tensor):
+        if self.padding_idx is not None:
+            self.padding_idx = Validator.check_int_range(self.padding_idx, -self.num_embeddings, self.num_embeddings,
+                                                         Validator.INC_LEFT, "padding_idx", self.cls_name)
+            if isinstance(init_tensor, Tensor) and init_tensor.init is not None:
+                init_tensor = init_tensor.init_data()
+            init_tensor[self.padding_idx] = 0
+
+        return init_tensor
+
+    def construct(self, input):
+        return ops.embedding(input, self.weight, self.padding_idx, self.max_norm,
+                             self.norm_type, self.scale_grad_by_freq)
+
+    def extend_repr(self):
+        return f'num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, ' \
+               f'padding_idx={self.padding_idx}, max_norm={self.max_norm}, norm_type={self.norm_type}, ' \
+               f'scale_grad_by_freq={self.scale_grad_by_freq}, dtype={self.dtype}'
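Note the effect of `_zero_weight_by_index`: when `padding_idx` is set and no `_weight` is supplied, that row of the freshly initialized weight is zeroed, so lookups of the padding index return a zero vector. A small check that only exercises the constructor (a sketch; the lookup op itself lists ``Ascend`` as its supported platform):

import mindspore as ms
from mindspore.nn.extend import Embedding

emb = Embedding(num_embeddings=5, embedding_dim=3, padding_idx=0)
print(emb.weight[0])      # padding row: zeros
print(emb.weight[1])      # other rows: drawn from Normal(sigma=1.0, mean=0.0)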
mindspore/nn/extend/layer/__init__.py
ADDED
@@ -0,0 +1,27 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+Layer.
+
+The high-level components(Cells) used to construct the neural network.
+"""
+from __future__ import absolute_import
+
+from mindspore.nn.extend.layer import normalization
+from mindspore.nn.extend.layer.normalization import *
+
+__all__ = []
+
+__all__.extend(normalization.__all__)