mindspore 2.3.0__cp310-cp310-win_amd64.whl → 2.4.1__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mindspore might be problematic.
- mindspore/.commit_id +1 -1
- mindspore/__init__.py +3 -1
- mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
- mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
- mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
- mindspore/_checkparam.py +50 -9
- mindspore/_extends/parse/compile_config.py +41 -0
- mindspore/_extends/parse/parser.py +9 -7
- mindspore/_extends/parse/standard_method.py +52 -14
- mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
- mindspore/amp.py +24 -10
- mindspore/common/__init__.py +6 -4
- mindspore/common/_pijit_context.py +190 -0
- mindspore/common/_register_for_tensor.py +2 -1
- mindspore/common/_tensor_overload.py +139 -0
- mindspore/common/api.py +102 -87
- mindspore/common/dump.py +5 -6
- mindspore/common/generator.py +1 -7
- mindspore/common/hook_handle.py +14 -26
- mindspore/common/initializer.py +51 -15
- mindspore/common/mindir_util.py +2 -2
- mindspore/common/parameter.py +62 -15
- mindspore/common/recompute.py +39 -9
- mindspore/common/sparse_tensor.py +7 -3
- mindspore/common/tensor.py +183 -37
- mindspore/communication/__init__.py +1 -1
- mindspore/communication/_comm_helper.py +38 -3
- mindspore/communication/comm_func.py +315 -60
- mindspore/communication/management.py +14 -14
- mindspore/context.py +132 -22
- mindspore/dataset/__init__.py +1 -1
- mindspore/dataset/audio/__init__.py +1 -1
- mindspore/dataset/core/config.py +7 -0
- mindspore/dataset/core/validator_helpers.py +7 -0
- mindspore/dataset/engine/cache_client.py +1 -1
- mindspore/dataset/engine/datasets.py +72 -44
- mindspore/dataset/engine/datasets_audio.py +7 -7
- mindspore/dataset/engine/datasets_standard_format.py +53 -3
- mindspore/dataset/engine/datasets_text.py +20 -20
- mindspore/dataset/engine/datasets_user_defined.py +174 -104
- mindspore/dataset/engine/datasets_vision.py +33 -33
- mindspore/dataset/engine/iterators.py +29 -0
- mindspore/dataset/engine/obs/util.py +7 -0
- mindspore/dataset/engine/queue.py +114 -60
- mindspore/dataset/engine/serializer_deserializer.py +2 -2
- mindspore/dataset/engine/validators.py +34 -14
- mindspore/dataset/text/__init__.py +1 -4
- mindspore/dataset/transforms/__init__.py +0 -3
- mindspore/dataset/utils/line_reader.py +2 -0
- mindspore/dataset/vision/__init__.py +1 -4
- mindspore/dataset/vision/utils.py +1 -1
- mindspore/dataset/vision/validators.py +2 -1
- mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
- mindspore/experimental/es/embedding_service.py +883 -0
- mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
- mindspore/experimental/llm_boost/__init__.py +21 -0
- mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
- mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
- mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
- mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
- mindspore/experimental/llm_boost/register.py +129 -0
- mindspore/experimental/llm_boost/utils.py +31 -0
- mindspore/experimental/optim/adamw.py +85 -0
- mindspore/experimental/optim/optimizer.py +3 -0
- mindspore/hal/__init__.py +3 -3
- mindspore/hal/contiguous_tensors_handle.py +175 -0
- mindspore/hal/stream.py +18 -0
- mindspore/include/api/model_group.h +13 -1
- mindspore/include/api/types.h +10 -10
- mindspore/include/dataset/config.h +2 -2
- mindspore/include/dataset/constants.h +2 -2
- mindspore/include/dataset/execute.h +2 -2
- mindspore/include/dataset/vision.h +4 -0
- mindspore/log.py +1 -1
- mindspore/mindrecord/filewriter.py +68 -51
- mindspore/mindspore_backend.dll +0 -0
- mindspore/mindspore_common.dll +0 -0
- mindspore/mindspore_core.dll +0 -0
- mindspore/mindspore_np_dtype.dll +0 -0
- mindspore/mindspore_ops.dll +0 -0
- mindspore/mint/__init__.py +983 -46
- mindspore/mint/distributed/__init__.py +31 -0
- mindspore/mint/distributed/distributed.py +254 -0
- mindspore/mint/nn/__init__.py +268 -23
- mindspore/mint/nn/functional.py +125 -19
- mindspore/mint/nn/layer/__init__.py +39 -0
- mindspore/mint/nn/layer/activation.py +133 -0
- mindspore/mint/nn/layer/normalization.py +477 -0
- mindspore/mint/nn/layer/pooling.py +110 -0
- mindspore/mint/optim/adamw.py +26 -13
- mindspore/mint/special/__init__.py +63 -0
- mindspore/multiprocessing/__init__.py +2 -1
- mindspore/nn/__init__.py +0 -1
- mindspore/nn/cell.py +276 -96
- mindspore/nn/layer/activation.py +211 -44
- mindspore/nn/layer/basic.py +137 -10
- mindspore/nn/layer/embedding.py +137 -2
- mindspore/nn/layer/normalization.py +101 -5
- mindspore/nn/layer/padding.py +34 -48
- mindspore/nn/layer/pooling.py +161 -7
- mindspore/nn/layer/transformer.py +3 -3
- mindspore/nn/loss/__init__.py +2 -2
- mindspore/nn/loss/loss.py +84 -6
- mindspore/nn/optim/__init__.py +2 -1
- mindspore/nn/optim/adadelta.py +1 -1
- mindspore/nn/optim/adam.py +1 -1
- mindspore/nn/optim/lamb.py +1 -1
- mindspore/nn/optim/tft_wrapper.py +124 -0
- mindspore/nn/wrap/cell_wrapper.py +12 -23
- mindspore/nn/wrap/grad_reducer.py +5 -5
- mindspore/nn/wrap/loss_scale.py +17 -3
- mindspore/numpy/__init__.py +1 -1
- mindspore/numpy/array_creations.py +65 -68
- mindspore/numpy/array_ops.py +64 -60
- mindspore/numpy/fft.py +610 -75
- mindspore/numpy/logic_ops.py +11 -10
- mindspore/numpy/math_ops.py +85 -84
- mindspore/numpy/utils_const.py +4 -4
- mindspore/opencv_core452.dll +0 -0
- mindspore/opencv_imgcodecs452.dll +0 -0
- mindspore/opencv_imgproc452.dll +0 -0
- mindspore/ops/__init__.py +6 -4
- mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
- mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
- mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
- mindspore/ops/_vmap/vmap_array_ops.py +2 -4
- mindspore/ops/_vmap/vmap_math_ops.py +17 -1
- mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
- mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
- mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
- mindspore/ops/auto_generate/gen_extend_func.py +767 -13
- mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
- mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
- mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
- mindspore/ops/composite/base.py +85 -48
- mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
- mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
- mindspore/ops/function/__init__.py +22 -0
- mindspore/ops/function/array_func.py +492 -153
- mindspore/ops/function/debug_func.py +113 -1
- mindspore/ops/function/fft_func.py +15 -2
- mindspore/ops/function/grad/grad_func.py +3 -2
- mindspore/ops/function/math_func.py +564 -207
- mindspore/ops/function/nn_func.py +817 -383
- mindspore/ops/function/other_func.py +3 -2
- mindspore/ops/function/random_func.py +402 -12
- mindspore/ops/function/reshard_func.py +13 -11
- mindspore/ops/function/sparse_unary_func.py +1 -1
- mindspore/ops/function/vmap_func.py +3 -2
- mindspore/ops/functional.py +24 -14
- mindspore/ops/op_info_register.py +3 -3
- mindspore/ops/operations/__init__.py +7 -2
- mindspore/ops/operations/_grad_ops.py +2 -76
- mindspore/ops/operations/_infer_ops.py +1 -1
- mindspore/ops/operations/_inner_ops.py +71 -94
- mindspore/ops/operations/array_ops.py +14 -146
- mindspore/ops/operations/comm_ops.py +63 -53
- mindspore/ops/operations/custom_ops.py +83 -19
- mindspore/ops/operations/debug_ops.py +42 -10
- mindspore/ops/operations/manually_defined/_inner.py +12 -0
- mindspore/ops/operations/manually_defined/ops_def.py +273 -20
- mindspore/ops/operations/math_ops.py +12 -223
- mindspore/ops/operations/nn_ops.py +20 -114
- mindspore/ops/operations/other_ops.py +7 -4
- mindspore/ops/operations/random_ops.py +46 -1
- mindspore/ops/primitive.py +18 -6
- mindspore/ops_generate/arg_dtype_cast.py +2 -0
- mindspore/ops_generate/gen_aclnn_implement.py +11 -11
- mindspore/ops_generate/gen_constants.py +36 -0
- mindspore/ops_generate/gen_ops.py +67 -52
- mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
- mindspore/ops_generate/gen_pyboost_func.py +131 -47
- mindspore/ops_generate/op_proto.py +10 -3
- mindspore/ops_generate/pyboost_utils.py +14 -1
- mindspore/ops_generate/template.py +43 -21
- mindspore/parallel/__init__.py +3 -1
- mindspore/parallel/_auto_parallel_context.py +31 -9
- mindspore/parallel/_cell_wrapper.py +85 -0
- mindspore/parallel/_parallel_serialization.py +47 -19
- mindspore/parallel/_tensor.py +127 -13
- mindspore/parallel/_utils.py +53 -22
- mindspore/parallel/algo_parameter_config.py +5 -5
- mindspore/parallel/checkpoint_transform.py +46 -39
- mindspore/parallel/cluster/process_entity/__init__.py +1 -1
- mindspore/parallel/cluster/process_entity/_api.py +31 -23
- mindspore/parallel/cluster/process_entity/_utils.py +2 -27
- mindspore/parallel/parameter_broadcast.py +3 -4
- mindspore/parallel/shard.py +162 -31
- mindspore/parallel/transform_safetensors.py +1146 -0
- mindspore/profiler/__init__.py +2 -1
- mindspore/profiler/common/constant.py +29 -0
- mindspore/profiler/common/registry.py +47 -0
- mindspore/profiler/common/util.py +28 -0
- mindspore/profiler/dynamic_profiler.py +694 -0
- mindspore/profiler/envprofiling.py +17 -19
- mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
- mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
- mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
- mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
- mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
- mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
- mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
- mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
- mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
- mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
- mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
- mindspore/profiler/parser/base_timeline_generator.py +19 -25
- mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
- mindspore/profiler/parser/framework_parser.py +1 -391
- mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
- mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
- mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
- mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
- mindspore/profiler/parser/memory_usage_parser.py +0 -154
- mindspore/profiler/parser/profiler_info.py +78 -6
- mindspore/profiler/profiler.py +153 -0
- mindspore/profiler/profiling.py +285 -413
- mindspore/rewrite/__init__.py +1 -2
- mindspore/rewrite/common/namespace.py +4 -4
- mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
- mindspore/run_check/_check_version.py +39 -104
- mindspore/safeguard/rewrite_obfuscation.py +591 -247
- mindspore/train/__init__.py +4 -3
- mindspore/train/_utils.py +105 -19
- mindspore/train/amp.py +171 -53
- mindspore/train/callback/__init__.py +2 -2
- mindspore/train/callback/_callback.py +4 -4
- mindspore/train/callback/_checkpoint.py +97 -31
- mindspore/train/callback/_cluster_monitor.py +1 -1
- mindspore/train/callback/_flops_collector.py +1 -0
- mindspore/train/callback/_loss_monitor.py +3 -3
- mindspore/train/callback/_on_request_exit.py +145 -31
- mindspore/train/callback/_summary_collector.py +5 -5
- mindspore/train/callback/_tft_register.py +375 -0
- mindspore/train/dataset_helper.py +15 -3
- mindspore/train/metrics/metric.py +3 -3
- mindspore/train/metrics/roc.py +4 -4
- mindspore/train/mind_ir_pb2.py +44 -39
- mindspore/train/model.py +154 -58
- mindspore/train/serialization.py +342 -128
- mindspore/utils/__init__.py +21 -0
- mindspore/utils/utils.py +60 -0
- mindspore/version.py +1 -1
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +248 -242
- mindspore/include/c_api/ms/abstract.h +0 -67
- mindspore/include/c_api/ms/attribute.h +0 -197
- mindspore/include/c_api/ms/base/handle_types.h +0 -43
- mindspore/include/c_api/ms/base/macros.h +0 -32
- mindspore/include/c_api/ms/base/status.h +0 -33
- mindspore/include/c_api/ms/base/types.h +0 -283
- mindspore/include/c_api/ms/context.h +0 -102
- mindspore/include/c_api/ms/graph.h +0 -160
- mindspore/include/c_api/ms/node.h +0 -606
- mindspore/include/c_api/ms/tensor.h +0 -161
- mindspore/include/c_api/ms/value.h +0 -84
- mindspore/mindspore_shared_lib.dll +0 -0
- mindspore/nn/extend/basic.py +0 -140
- mindspore/nn/extend/embedding.py +0 -143
- mindspore/nn/extend/layer/normalization.py +0 -109
- mindspore/nn/extend/pooling.py +0 -117
- mindspore/nn/layer/embedding_service.py +0 -531
- mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
- mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
- mindspore/ops/extend/__init__.py +0 -53
- mindspore/ops/extend/array_func.py +0 -218
- mindspore/ops/extend/math_func.py +0 -76
- mindspore/ops/extend/nn_func.py +0 -308
- mindspore/ops/silent_check.py +0 -162
- mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
- mindspore/profiler/parser/msadvisor_parser.py +0 -240
- mindspore/train/callback/_mindio_ttp.py +0 -443
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
- {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/activation.py
CHANGED
@@ -33,6 +33,7 @@ __all__ = ['Softmin',
            'Softmax',
            'Softmax2d',
            'LogSoftmax',
+           'LogSoftmaxExt',
            'ReLU',
            'ReLU6',
            'RReLU',
@@ -46,6 +47,7 @@ __all__ = ['Softmin',
            'Sigmoid',
            'Softsign',
            'PReLU',
+           'PReLUExt',
            'get_activation',
            'LeakyReLU',
            'HSigmoid',
@@ -279,6 +281,35 @@ class Softmax(Cell):
         return self.softmax(input)


+class SoftmaxExt(Cell):
+    r"""
+    Applies the Softmax function to an n-dimensional input Tensor.
+
+    For details, please refer to :func:`mindspore.mint.nn.functional.softmax`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> softmax = nn.SoftmaxExt()
+        >>> output = softmax(input)
+        >>> print(output)
+        [0.03168 0.01166 0.0861 0.636 0.2341 ]
+    """
+
+    def __init__(self, dim=None):
+        """Initialize Softmax."""
+        super(SoftmaxExt, self).__init__()
+        self.dim = dim
+
+    def construct(self, input):
+        return ops.function.nn_func.softmax_ext(input, self.dim)
+
+
 class LogSoftmax(Cell):
     r"""
     Applies the LogSoftmax function to n-dimensional input tensor element-wise.
@@ -329,6 +360,51 @@ class LogSoftmax(Cell):
         return self.log_softmax(x)


+class LogSoftmaxExt(Cell):
+    r"""
+    Applies the Log Softmax function to the input tensor on the specified axis.
+    Supposes a slice in the given axis, :math:`x` for each element :math:`x_i`,
+    the Log Softmax function is shown as follows:
+
+    .. math::
+        \text{output}(x_i) = \log \left(\frac{\exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),
+
+    where :math:`N` is the length of the Tensor.
+
+    Args:
+        dim (int, optional): The axis to perform the Log softmax operation. Default: ``None`` .
+
+    Returns:
+        Tensor, with the same shape as the input.
+
+    Raises:
+        ValueError: If `dim` is not in range [-len(input.shape), len(input.shape)).
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> log_softmax = nn.LogSoftmaxExt(dim=-1)
+        >>> output = log_softmax(x)
+        >>> print(output)
+        [[-5.00672150e+00 -6.72150636e-03 -1.20067215e+01]
+         [-7.00091219e+00 -1.40009127e+01 -9.12250078e-04]]
+    """
+
+    def __init__(self, dim=None):
+        """Initialize LogSoftmaxExt."""
+        super(LogSoftmaxExt, self).__init__()
+        self.log_softmax = P.LogSoftmaxExt()
+        self.dim = dim
+
+    def construct(self, x):
+        return self.log_softmax(x, dim=self.dim)
+
+
 class ELU(Cell):
     r"""
     Applies the exponential linear unit function element-wise.
@@ -434,8 +510,8 @@ class ReLU(Cell):
         super(ReLU, self).__init__()
         self.relu = P.ReLU()

-    def construct(self,
-        return self.relu(
+    def construct(self, input):
+        return self.relu(input)


 class ReLU6(Cell):
@@ -898,6 +974,13 @@ class GELU(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.

+    Note:
+        when calculating the input gradient of GELU with an input value of infinity, there are differences
+        in the output of the backward between ``Ascend`` and ``GPU``.
+        when x is -inf, the computation result of ``Ascend`` is 0, and the computation result of ``GPU`` is Nan.
+        when x is inf, the computation result of ``Ascend`` is dy, and the computation result of ``GPU`` is Nan.
+        In mathematical terms, the result of Ascend has higher precision.
+
     Raises:
         TypeError: If dtype of `x` is not one of float16, float32, or float64.

@@ -1164,14 +1247,85 @@ class PReLU(Cell):
         return self.prelu(x, F.cast(self.w, x.dtype))


+class PReLUExt(Cell):
+    r"""
+    Applies PReLU activation function element-wise.
+
+    PReLU is defined as:
+
+    .. math::
+
+        PReLU(x_i)= \max(0, x_i) + w * \min(0, x_i),
+
+    where :math:`x_i` is an element of an channel of the input.
+
+    Here :math:`w` is a learnable parameter with a default initial value 0.25.
+    Parameter :math:`w` has dimensionality of the argument channel. If called without argument
+    channel, a single parameter :math:`w` will be shared across all channels.
+
+    PReLU Activation Function Graph:
+
+    .. image:: ../images/PReLU2.png
+        :align: center
+
+    .. note::
+        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
+        no channel dim and the number of channels = 1.
+
+    Args:
+        num_parameters (int): number of `w` to learn. Although it takes an int as input,
+            there is only two legitimate values: 1, or the number of channels at Tensor `input`. Default: ``1`` .
+        init (float): the initial value of `w`. Default: ``0.25`` .
+        dtype (mindspore.dtype, optional): the type of `w`. Default: ``None`` . Supported data type
+            is {float16, float32, bfloat16}.
+
+    Inputs:
+        - **input** (Tensor) - The input of PReLU.
+
+    Outputs:
+        Tensor, with the same dtype and shape as the `input`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[[[0.1, 0.6], [0.9, 0.9]]]]), mindspore.float32)
+        >>> prelu = nn.PReLUExt()
+        >>> output = prelu(x)
+        >>> print(output)
+        [[[[0.1 0.6]
+           [0.9 0.9]]]]
+
+    """
+
+    def __init__(self, num_parameters=1, init=0.25, dtype=None):
+        """Initialize PReLUExt."""
+        super(PReLUExt, self).__init__()
+        tmp = np.empty((num_parameters,), dtype=np.float32)
+        tmp.fill(init)
+        w = Tensor(tmp, dtype=dtype)
+        self.weight = Parameter(w, name='weight')
+
+    def construct(self, input):
+        return ops.prelu(input, self.weight)
+
+
 class HSwish(Cell):
     r"""
-    Applies
+    Applies Hard Swish activation function element-wise.

     Hard swish is defined as:

     .. math::
-        \text{
+        \text{Hardswish}(input) =
+        \begin{cases}
+        0, & \text{ if } input \leq -3, \\
+        input, & \text{ if } input \geq +3, \\
+        input*(input + 3)/6, & \text{ otherwise }
+        \end{cases}

     HSwish Activation Function Graph:

@@ -1179,14 +1333,14 @@ class HSwish(Cell):
         :align: center

     Inputs:
-        - **
-          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
+        - **input** (Tensor) - The input of HSwish.

     Outputs:
-        Tensor, with the same type and shape as the `
+        Tensor, with the same type and shape as the `input`.

     Raises:
-        TypeError: If
+        TypeError: If `input` is not a tensor.
+        TypeError: If `input` is neither int nor float.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1195,9 +1349,9 @@ class HSwish(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>>
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
         >>> hswish = nn.HSwish()
-        >>> result = hswish(
+        >>> result = hswish(input)
         >>> print(result)
         [-0.3333 -0.3333 0. 1.667 0.6665]
     """
@@ -1207,18 +1361,23 @@ class HSwish(Cell):
         super(HSwish, self).__init__()
         self.hswish = P.HSwish()

-    def construct(self,
-        return self.hswish(
+    def construct(self, input):
+        return self.hswish(input)


 class HSigmoid(Cell):
     r"""
-    Applies Hard
+    Applies Hard Sigmoid activation function element-wise.

-    Hard
+    Hard Sigmoid is defined as:

     .. math::
-        \text{
+        \text{Hardsigmoid}(input) =
+        \begin{cases}
+        0, & \text{ if } input \leq -3, \\
+        1, & \text{ if } input \geq +3, \\
+        input/6 + 1/2, & \text{ otherwise }
+        \end{cases}

     HSigmoid Activation Function Graph:

@@ -1226,13 +1385,14 @@ class HSigmoid(Cell):
         :align: center

     Inputs:
-        - **
+        - **input** (Tensor) - The input of HSigmoid.

     Outputs:
-        Tensor, with the same type and shape as the `
+        Tensor, with the same type and shape as the `input`.

     Raises:
-        TypeError: If `
+        TypeError: If `input` is not a Tensor.
+        TypeError: If `input` is neither int nor float.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1241,9 +1401,9 @@ class HSigmoid(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>>
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
         >>> hsigmoid = nn.HSigmoid()
-        >>> result = hsigmoid(
+        >>> result = hsigmoid(input)
         >>> print(result)
         [0.3333 0.1666 0.5 0.8335 0.6665]
     """
@@ -1253,8 +1413,8 @@ class HSigmoid(Cell):
         super(HSigmoid, self).__init__()
         self.hsigmoid = P.HSigmoid()

-    def construct(self,
-        return self.hsigmoid(
+    def construct(self, input):
+        return self.hsigmoid(input)


 class LogSigmoid(Cell):
@@ -1370,21 +1530,22 @@ class SoftShrink(Cell):
         :align: center

     Args:
-        lambd (
-
+        lambd (number, optional): The threshold :math:`\lambda` defined by the Soft Shrink formula.
+            It should be greater than or equal to 0, default: ``0.5`` .

     Inputs:
-        - **
-
+        - **input** (Tensor) - The input of Soft Shrink. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32.

     Outputs:
-        Tensor,
+        Tensor, the same shape and data type as the input.

     Raises:
-        TypeError: If lambd is not a float.
-        TypeError: If
-        TypeError: If dtype of
-        ValueError: If lambd is less than 0.
+        TypeError: If `lambd` is not a float, int or bool.
+        TypeError: If `input` is not a tensor.
+        TypeError: If dtype of `input` is not float16, float32 or bfloat16.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1393,9 +1554,9 @@ class SoftShrink(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>>
+        >>> input = Tensor(np.array([[ 0.5297, 0.7871, 1.1754], [ 0.7836, 0.6218, -1.1542]]), mindspore.float16)
         >>> softshrink = nn.SoftShrink()
-        >>> output = softshrink(
+        >>> output = softshrink(input)
         >>> print(output)
         [[ 0.02979 0.287 0.676 ]
          [ 0.2837 0.1216 -0.6543 ]]
@@ -1405,8 +1566,8 @@ class SoftShrink(Cell):
         super(SoftShrink, self).__init__()
         self.softshrink = P.SoftShrink(lambd)

-    def construct(self,
-        output = self.softshrink(
+    def construct(self, input):
+        output = self.softshrink(input)
         return output


@@ -1430,17 +1591,21 @@ class HShrink(Cell):
         :align: center

     Args:
-        lambd (
+        lambd (number, optional): The threshold :math:`\lambda` defined by the Hard Shrink formula. Default: ``0.5`` .

     Inputs:
-        - **
+        - **input** (Tensor) - The input of Hard Shrink. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32.

     Outputs:
         Tensor, the same shape and data type as the input.

     Raises:
-        TypeError: If `lambd` is not a float.
-        TypeError: If
+        TypeError: If `lambd` is not a float, int or bool.
+        TypeError: If `input` is not a tensor.
+        TypeError: If dtype of `input` is not float16, float32 or bfloat16.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1449,20 +1614,20 @@ class HShrink(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>>
+        >>> input = Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mindspore.float32)
         >>> hshrink = nn.HShrink()
-        >>> output = hshrink(
+        >>> output = hshrink(input)
         >>> print(output)
         [[ 0. 1. 2. ]
-
+         [ 0. 0. -2.1233]]
     """

     def __init__(self, lambd=0.5):
         super(HShrink, self).__init__()
         self.hshrink = P.HShrink(lambd)

-    def construct(self,
-        return self.hshrink(
+    def construct(self, input):
+        return self.hshrink(input)


 class Threshold(Cell):
@@ -1602,6 +1767,7 @@ _activation = {
     'softmax': Softmax,
     'softmax2d': Softmax2d,
     'logsoftmax': LogSoftmax,
+    'logsoftmaxExt': LogSoftmaxExt,
     'relu': ReLU,
     'relu6': ReLU6,
     'rrelu': RReLU,
@@ -1615,6 +1781,7 @@ _activation = {
     'sigmoid': Sigmoid,
     'softsign': Softsign,
     'prelu': PReLU,
+    'preluExt': PReLUExt,
     'leakyrelu': LeakyReLU,
     'hswish': HSwish,
     'hsigmoid': HSigmoid,
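In short, the activation.py hunks above add the SoftmaxExt, LogSoftmaxExt and PReLUExt cells (their docstrings list only ``Ascend`` as a supported platform, and the latter two are registered under the 'logsoftmaxExt' and 'preluExt' keys of the _activation table), and they rename the construct() parameter of several existing cells (ReLU, HSwish, HSigmoid, SoftShrink, HShrink) to `input`. A minimal sketch of what that means for callers is below; it is an assumption-level illustration only (PyNative mode, ReLU used because it runs on CPU, and the pre-2.4 parameter name is truncated in this diff, so the keyword-style call is hypothetical).

import numpy as np
import mindspore
from mindspore import Tensor, nn

t = Tensor(np.array([-1.0, 0.0, 2.0]), mindspore.float32)

# Positional call sites are unaffected by the construct() parameter rename.
relu = nn.ReLU()
print(relu(t))        # [0. 0. 2.]

# Keyword-style call sites, if any exist, must use the new name `input`
# from 2.4.1 on; the old parameter name is truncated in the hunks above.
print(relu(input=t))  # [0. 0. 2.]

# The new cells are also reachable through the activation registry; note that
# their docstrings list only ``Ascend`` as a supported platform for construct().
log_softmax_ext = nn.get_activation('logsoftmaxExt')  # nn.LogSoftmaxExt instance
prelu_ext = nn.get_activation('preluExt')             # nn.PReLUExt instance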
mindspore/nn/layer/basic.py
CHANGED
@@ -40,7 +40,7 @@ from mindspore.common._decorator import deprecated
 from mindspore.ops.auto_generate import dropout_ext_op, fold_ext
 from mindspore.common.generator import default_generator

-__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
+__all__ = ['Dropout', 'Flatten', 'Dense', 'Linear', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
            'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Dropout1d',
            'Dropout2d', 'Dropout3d', 'Upsample', 'Roll', 'Identity', 'Unflatten', 'DropoutExt']

@@ -510,8 +510,8 @@ class UpsampleExt(Cell):
         self.align_corners = align_corners
         self.recompute_scale_factor = recompute_scale_factor

-    def construct(self,
-        out = interpolate_ext(
+    def construct(self, input):
+        out = interpolate_ext(input, self.size, self.scale_factor, self.mode,
                               self.align_corners, self.recompute_scale_factor)
         return out

@@ -579,11 +579,15 @@ class Identity(Cell):
     r"""
     A placeholder identity operator that returns the same as input.

+    Args:
+        args (Any): Any argument.
+        kwargs (Any): Any keyword argument.
+
     Inputs:
-        - **
+        - **input** (Any) - The input of Identity.

     Outputs:
-        The same as `
+        The same as `input`.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -592,19 +596,19 @@ class Identity(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>>
+        >>> input = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
         >>> net = nn.Identity()
-        >>> output = net(
+        >>> output = net(input)
         >>> print(output)
         [1 2 3 4]
     """

-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         """Initialize Identity."""
         super(Identity, self).__init__()

-    def construct(self,
-        return
+    def construct(self, input):
+        return input


 class Dense(Cell):
@@ -621,6 +625,9 @@ class Dense(Cell):
     data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
     with the same data type as the :math:`X` created by the layer (only if has_bias is True).

+    .. warning::
+        In PYNATIVE mode, if `bias` is ``False`` , the `x` cannot be greater than 6D.
+
     Args:
         in_channels (int): The number of channels in the input space.
         out_channels (int): The number of channels in the output space.
@@ -635,6 +642,8 @@ class Dense(Cell):
             layer. Both activation name, e.g. 'relu', and mindspore activation function, e.g. mindspore.ops.ReLU(),
             are supported. Default: ``None`` .
         dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+            When `weight_init` is Tensor, Parameter has the same data type as `weight_init` ,
+            in other cases, Parameter has the same data type as `dtype`, the same goes for `bias_init`.

     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@@ -651,6 +660,7 @@ class Dense(Cell):
             is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
         ValueError: If length of shape of `bias_init` is not equal to 1
             or shape[0] of `bias_init` is not equal to `out_channels`.
+        RuntimeError: If `bias` is ``False`` and `x` is greater than 6D in PYNATIVE mode.

     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -743,6 +753,123 @@ class Dense(Cell):
         return s


+class Linear(Cell):
+    r"""
+    The linear connected layer.
+
+    Applies linear connected layer for the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = X * kernel + bias
+
+    .. warning::
+        In PYNATIVE mode, if `bias` is ``False`` , the `x` cannot be greater than 6D.
+
+    where :math:`X` is the input tensors, :math:`\text{kernel}` is a weight matrix with the same
+    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+    with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+    Args:
+        in_features (int): The number of features in the input space.
+        out_features (int): The number of features in the output space.
+        bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            weight will be initialized using HeUniform.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            bias will be initialized using Uniform.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None`` .
+            If `dtype` is ``None`` , `dtype` is set to ``mstype.float32`` when initializing the method.
+            When `weight_init` is Tensor, Parameter has the same data type as `weight_init` ,
+            in other cases, Parameter has the same data type as `dtype`, the same goes for `bias_init`.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
+          to :math:`in\_features` in `Inputs`.
+
+    Outputs:
+        Tensor of shape :math:`(*, out\_features)`.
+
+    Raises:
+        TypeError: If `in_features` or `out_features` is not an int.
+        TypeError: If `bias` is not a bool.
+        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+            is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
+        ValueError: If length of shape of `bias_init` is not equal to 1
+            or shape[0] of `bias_init` is not equal to `out_features`.
+        RuntimeError: If `bias` is ``False`` and `x` is greater than 6D in PYNATIVE mode.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore import nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = nn.mint.nn.Linear(3, 4)
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 4)
+    """
+
+    @cell_attr_register(attrs=['has_bias'])
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 weight_init=None,
+                 bias_init=None,
+                 dtype=None):
+        """Initialize Linear."""
+        super(Linear, self).__init__()
+        self.in_features = Validator.check_positive_int(
+            in_features, "in_features", self.cls_name)
+        self.out_features = Validator.check_positive_int(
+            out_features, "out_features", self.cls_name)
+        self.has_bias = Validator.check_bool(
+            bias, "has_bias", self.cls_name)
+        self.dense = P.Dense()
+        if dtype is None:
+            dtype = mstype.float32
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
+                    weight_init.shape[1] != in_features:
+                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+                                 f"be equal to 2, and the first dim must be equal to 'out_features', and the "
+                                 f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
+                                 f"'out_features': {out_features}, 'in_features': {in_features}.")
+        if weight_init is None:
+            weight_init = HeUniform(math.sqrt(5))
+        self.weight = Parameter(initializer(
+            weight_init, [out_features, in_features], dtype=dtype), name="weight")
+
+        self.bias = None
+        if self.has_bias:
+            if isinstance(bias_init, Tensor):
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
+                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+                                     f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
+                                     f"'bias_init': {bias_init}, 'out_features': {out_features}.")
+            if bias_init is None:
+                bound = 1 / math.sqrt(in_features)
+                bias_init = Uniform(scale=bound)
+            self.bias = Parameter(initializer(
+                bias_init, [out_features], dtype=dtype), name="bias")
+
+    def construct(self, x):
+        x = self.dense(x, self.weight, self.bias)
+        return x
+
+    def extend_repr(self):
+        s = f'input_features={self.in_features}, output_features={self.out_features}'
+        if self.has_bias:
+            s += f', has_bias={self.has_bias}'
+        return s
+
+
 @constexpr
 def _is_equal_one(x):
     if x is None: