mindspore-2.3.0-cp39-cp39-win_amd64.whl → mindspore-2.4.1-cp39-cp39-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (287)
  1. mindspore/.commit_id +1 -1
  2. mindspore/__init__.py +3 -1
  3. mindspore/_c_dataengine.cp39-win_amd64.pyd +0 -0
  4. mindspore/_c_expression.cp39-win_amd64.pyd +0 -0
  5. mindspore/_c_mindrecord.cp39-win_amd64.pyd +0 -0
  6. mindspore/_checkparam.py +50 -9
  7. mindspore/_extends/parse/compile_config.py +41 -0
  8. mindspore/_extends/parse/parser.py +9 -7
  9. mindspore/_extends/parse/standard_method.py +52 -14
  10. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  11. mindspore/amp.py +24 -10
  12. mindspore/avcodec-59.dll +0 -0
  13. mindspore/avdevice-59.dll +0 -0
  14. mindspore/avfilter-8.dll +0 -0
  15. mindspore/avformat-59.dll +0 -0
  16. mindspore/avutil-57.dll +0 -0
  17. mindspore/common/__init__.py +6 -4
  18. mindspore/common/_pijit_context.py +190 -0
  19. mindspore/common/_register_for_tensor.py +2 -1
  20. mindspore/common/_tensor_overload.py +139 -0
  21. mindspore/common/api.py +102 -87
  22. mindspore/common/dump.py +5 -6
  23. mindspore/common/generator.py +1 -7
  24. mindspore/common/hook_handle.py +14 -26
  25. mindspore/common/initializer.py +51 -15
  26. mindspore/common/mindir_util.py +2 -2
  27. mindspore/common/parameter.py +62 -15
  28. mindspore/common/recompute.py +39 -9
  29. mindspore/common/sparse_tensor.py +7 -3
  30. mindspore/common/tensor.py +183 -37
  31. mindspore/communication/__init__.py +1 -1
  32. mindspore/communication/_comm_helper.py +38 -3
  33. mindspore/communication/comm_func.py +315 -60
  34. mindspore/communication/management.py +14 -14
  35. mindspore/context.py +132 -22
  36. mindspore/dataset/__init__.py +1 -1
  37. mindspore/dataset/audio/__init__.py +1 -1
  38. mindspore/dataset/core/config.py +7 -0
  39. mindspore/dataset/core/validator_helpers.py +7 -0
  40. mindspore/dataset/engine/cache_client.py +1 -1
  41. mindspore/dataset/engine/datasets.py +72 -44
  42. mindspore/dataset/engine/datasets_audio.py +7 -7
  43. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  44. mindspore/dataset/engine/datasets_text.py +20 -20
  45. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  46. mindspore/dataset/engine/datasets_vision.py +33 -33
  47. mindspore/dataset/engine/iterators.py +29 -0
  48. mindspore/dataset/engine/obs/util.py +7 -0
  49. mindspore/dataset/engine/queue.py +114 -60
  50. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  51. mindspore/dataset/engine/validators.py +34 -14
  52. mindspore/dataset/text/__init__.py +1 -4
  53. mindspore/dataset/transforms/__init__.py +0 -3
  54. mindspore/dataset/utils/line_reader.py +2 -0
  55. mindspore/dataset/vision/__init__.py +1 -4
  56. mindspore/dataset/vision/utils.py +1 -1
  57. mindspore/dataset/vision/validators.py +2 -1
  58. mindspore/dnnl.dll +0 -0
  59. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  60. mindspore/experimental/es/embedding_service.py +883 -0
  61. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  62. mindspore/experimental/llm_boost/__init__.py +21 -0
  63. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  64. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  65. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  66. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  67. mindspore/experimental/llm_boost/register.py +129 -0
  68. mindspore/experimental/llm_boost/utils.py +31 -0
  69. mindspore/experimental/optim/adamw.py +85 -0
  70. mindspore/experimental/optim/optimizer.py +3 -0
  71. mindspore/hal/__init__.py +3 -3
  72. mindspore/hal/contiguous_tensors_handle.py +175 -0
  73. mindspore/hal/stream.py +18 -0
  74. mindspore/include/api/model_group.h +13 -1
  75. mindspore/include/api/types.h +10 -10
  76. mindspore/include/dataset/config.h +2 -2
  77. mindspore/include/dataset/constants.h +2 -2
  78. mindspore/include/dataset/execute.h +2 -2
  79. mindspore/include/dataset/vision.h +4 -0
  80. mindspore/jpeg62.dll +0 -0
  81. mindspore/log.py +1 -1
  82. mindspore/mindrecord/filewriter.py +68 -51
  83. mindspore/mindspore_backend.dll +0 -0
  84. mindspore/mindspore_common.dll +0 -0
  85. mindspore/mindspore_core.dll +0 -0
  86. mindspore/mindspore_glog.dll +0 -0
  87. mindspore/mindspore_np_dtype.dll +0 -0
  88. mindspore/mindspore_ops.dll +0 -0
  89. mindspore/mint/__init__.py +983 -46
  90. mindspore/mint/distributed/__init__.py +31 -0
  91. mindspore/mint/distributed/distributed.py +254 -0
  92. mindspore/mint/nn/__init__.py +268 -23
  93. mindspore/mint/nn/functional.py +125 -19
  94. mindspore/mint/nn/layer/__init__.py +39 -0
  95. mindspore/mint/nn/layer/activation.py +133 -0
  96. mindspore/mint/nn/layer/normalization.py +477 -0
  97. mindspore/mint/nn/layer/pooling.py +110 -0
  98. mindspore/mint/optim/adamw.py +26 -13
  99. mindspore/mint/special/__init__.py +63 -0
  100. mindspore/multiprocessing/__init__.py +2 -1
  101. mindspore/nn/__init__.py +0 -1
  102. mindspore/nn/cell.py +276 -96
  103. mindspore/nn/layer/activation.py +211 -44
  104. mindspore/nn/layer/basic.py +137 -10
  105. mindspore/nn/layer/embedding.py +137 -2
  106. mindspore/nn/layer/normalization.py +101 -5
  107. mindspore/nn/layer/padding.py +34 -48
  108. mindspore/nn/layer/pooling.py +161 -7
  109. mindspore/nn/layer/transformer.py +3 -3
  110. mindspore/nn/loss/__init__.py +2 -2
  111. mindspore/nn/loss/loss.py +84 -6
  112. mindspore/nn/optim/__init__.py +2 -1
  113. mindspore/nn/optim/adadelta.py +1 -1
  114. mindspore/nn/optim/adam.py +1 -1
  115. mindspore/nn/optim/lamb.py +1 -1
  116. mindspore/nn/optim/tft_wrapper.py +124 -0
  117. mindspore/nn/wrap/cell_wrapper.py +12 -23
  118. mindspore/nn/wrap/grad_reducer.py +5 -5
  119. mindspore/nn/wrap/loss_scale.py +17 -3
  120. mindspore/numpy/__init__.py +1 -1
  121. mindspore/numpy/array_creations.py +65 -68
  122. mindspore/numpy/array_ops.py +64 -60
  123. mindspore/numpy/fft.py +610 -75
  124. mindspore/numpy/logic_ops.py +11 -10
  125. mindspore/numpy/math_ops.py +85 -84
  126. mindspore/numpy/utils_const.py +4 -4
  127. mindspore/opencv_core452.dll +0 -0
  128. mindspore/opencv_imgcodecs452.dll +0 -0
  129. mindspore/opencv_imgproc452.dll +0 -0
  130. mindspore/ops/__init__.py +6 -4
  131. mindspore/ops/_grad_experimental/grad_array_ops.py +0 -11
  132. mindspore/ops/_grad_experimental/grad_comm_ops.py +67 -4
  133. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  134. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  135. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  136. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  137. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +91 -7
  138. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  139. mindspore/ops/auto_generate/gen_extend_func.py +767 -13
  140. mindspore/ops/auto_generate/gen_ops_def.py +2452 -364
  141. mindspore/ops/auto_generate/gen_ops_prim.py +5442 -1756
  142. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  143. mindspore/ops/composite/base.py +85 -48
  144. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  145. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  146. mindspore/ops/function/__init__.py +22 -0
  147. mindspore/ops/function/array_func.py +492 -153
  148. mindspore/ops/function/debug_func.py +113 -1
  149. mindspore/ops/function/fft_func.py +15 -2
  150. mindspore/ops/function/grad/grad_func.py +3 -2
  151. mindspore/ops/function/math_func.py +564 -207
  152. mindspore/ops/function/nn_func.py +817 -383
  153. mindspore/ops/function/other_func.py +3 -2
  154. mindspore/ops/function/random_func.py +402 -12
  155. mindspore/ops/function/reshard_func.py +13 -11
  156. mindspore/ops/function/sparse_unary_func.py +1 -1
  157. mindspore/ops/function/vmap_func.py +3 -2
  158. mindspore/ops/functional.py +24 -14
  159. mindspore/ops/op_info_register.py +3 -3
  160. mindspore/ops/operations/__init__.py +7 -2
  161. mindspore/ops/operations/_grad_ops.py +2 -76
  162. mindspore/ops/operations/_infer_ops.py +1 -1
  163. mindspore/ops/operations/_inner_ops.py +71 -94
  164. mindspore/ops/operations/array_ops.py +14 -146
  165. mindspore/ops/operations/comm_ops.py +63 -53
  166. mindspore/ops/operations/custom_ops.py +83 -19
  167. mindspore/ops/operations/debug_ops.py +42 -10
  168. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  169. mindspore/ops/operations/manually_defined/ops_def.py +273 -20
  170. mindspore/ops/operations/math_ops.py +12 -223
  171. mindspore/ops/operations/nn_ops.py +20 -114
  172. mindspore/ops/operations/other_ops.py +7 -4
  173. mindspore/ops/operations/random_ops.py +46 -1
  174. mindspore/ops/primitive.py +18 -6
  175. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  176. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  177. mindspore/ops_generate/gen_constants.py +36 -0
  178. mindspore/ops_generate/gen_ops.py +67 -52
  179. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  180. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  181. mindspore/ops_generate/op_proto.py +10 -3
  182. mindspore/ops_generate/pyboost_utils.py +14 -1
  183. mindspore/ops_generate/template.py +43 -21
  184. mindspore/parallel/__init__.py +3 -1
  185. mindspore/parallel/_auto_parallel_context.py +31 -9
  186. mindspore/parallel/_cell_wrapper.py +85 -0
  187. mindspore/parallel/_parallel_serialization.py +47 -19
  188. mindspore/parallel/_tensor.py +127 -13
  189. mindspore/parallel/_utils.py +53 -22
  190. mindspore/parallel/algo_parameter_config.py +5 -5
  191. mindspore/parallel/checkpoint_transform.py +46 -39
  192. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  193. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  194. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  195. mindspore/parallel/parameter_broadcast.py +3 -4
  196. mindspore/parallel/shard.py +162 -31
  197. mindspore/parallel/transform_safetensors.py +1146 -0
  198. mindspore/profiler/__init__.py +2 -1
  199. mindspore/profiler/common/constant.py +29 -0
  200. mindspore/profiler/common/registry.py +47 -0
  201. mindspore/profiler/common/util.py +28 -0
  202. mindspore/profiler/dynamic_profiler.py +694 -0
  203. mindspore/profiler/envprofiling.py +17 -19
  204. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  205. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  206. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  207. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  208. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  209. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  210. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  211. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  212. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  213. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  214. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  215. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  216. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  217. mindspore/profiler/parser/framework_parser.py +1 -391
  218. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  219. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  220. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  221. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  222. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  223. mindspore/profiler/parser/profiler_info.py +78 -6
  224. mindspore/profiler/profiler.py +153 -0
  225. mindspore/profiler/profiling.py +285 -413
  226. mindspore/rewrite/__init__.py +1 -2
  227. mindspore/rewrite/common/namespace.py +4 -4
  228. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  229. mindspore/run_check/_check_version.py +39 -104
  230. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  231. mindspore/swresample-4.dll +0 -0
  232. mindspore/swscale-6.dll +0 -0
  233. mindspore/tinyxml2.dll +0 -0
  234. mindspore/train/__init__.py +4 -3
  235. mindspore/train/_utils.py +105 -19
  236. mindspore/train/amp.py +171 -53
  237. mindspore/train/callback/__init__.py +2 -2
  238. mindspore/train/callback/_callback.py +4 -4
  239. mindspore/train/callback/_checkpoint.py +97 -31
  240. mindspore/train/callback/_cluster_monitor.py +1 -1
  241. mindspore/train/callback/_flops_collector.py +1 -0
  242. mindspore/train/callback/_loss_monitor.py +3 -3
  243. mindspore/train/callback/_on_request_exit.py +145 -31
  244. mindspore/train/callback/_summary_collector.py +5 -5
  245. mindspore/train/callback/_tft_register.py +375 -0
  246. mindspore/train/dataset_helper.py +15 -3
  247. mindspore/train/metrics/metric.py +3 -3
  248. mindspore/train/metrics/roc.py +4 -4
  249. mindspore/train/mind_ir_pb2.py +44 -39
  250. mindspore/train/model.py +154 -58
  251. mindspore/train/serialization.py +342 -128
  252. mindspore/turbojpeg.dll +0 -0
  253. mindspore/utils/__init__.py +21 -0
  254. mindspore/utils/utils.py +60 -0
  255. mindspore/version.py +1 -1
  256. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/METADATA +13 -7
  257. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/RECORD +260 -254
  258. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/WHEEL +1 -1
  259. mindspore/include/c_api/ms/abstract.h +0 -67
  260. mindspore/include/c_api/ms/attribute.h +0 -197
  261. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  262. mindspore/include/c_api/ms/base/macros.h +0 -32
  263. mindspore/include/c_api/ms/base/status.h +0 -33
  264. mindspore/include/c_api/ms/base/types.h +0 -283
  265. mindspore/include/c_api/ms/context.h +0 -102
  266. mindspore/include/c_api/ms/graph.h +0 -160
  267. mindspore/include/c_api/ms/node.h +0 -606
  268. mindspore/include/c_api/ms/tensor.h +0 -161
  269. mindspore/include/c_api/ms/value.h +0 -84
  270. mindspore/mindspore_shared_lib.dll +0 -0
  271. mindspore/nn/extend/basic.py +0 -140
  272. mindspore/nn/extend/embedding.py +0 -143
  273. mindspore/nn/extend/layer/normalization.py +0 -109
  274. mindspore/nn/extend/pooling.py +0 -117
  275. mindspore/nn/layer/embedding_service.py +0 -531
  276. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  277. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  278. mindspore/ops/extend/__init__.py +0 -53
  279. mindspore/ops/extend/array_func.py +0 -218
  280. mindspore/ops/extend/math_func.py +0 -76
  281. mindspore/ops/extend/nn_func.py +0 -308
  282. mindspore/ops/silent_check.py +0 -162
  283. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  284. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  285. mindspore/train/callback/_mindio_ttp.py +0 -443
  286. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/entry_points.txt +0 -0
  287. {mindspore-2.3.0.dist-info → mindspore-2.4.1.dist-info}/top_level.txt +0 -0
mindspore/nn/layer/activation.py:

@@ -33,6 +33,7 @@ __all__ = ['Softmin',
            'Softmax',
            'Softmax2d',
            'LogSoftmax',
+           'LogSoftmaxExt',
            'ReLU',
            'ReLU6',
            'RReLU',
@@ -46,6 +47,7 @@ __all__ = ['Softmin',
            'Sigmoid',
            'Softsign',
            'PReLU',
+           'PReLUExt',
            'get_activation',
            'LeakyReLU',
            'HSigmoid',
@@ -279,6 +281,35 @@ class Softmax(Cell):
         return self.softmax(input)
 
 
+class SoftmaxExt(Cell):
+    r"""
+    Applies the Softmax function to an n-dimensional input Tensor.
+
+    For details, please refer to :func:`mindspore.mint.nn.functional.softmax`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> softmax = nn.SoftmaxExt()
+        >>> output = softmax(input)
+        >>> print(output)
+        [0.03168 0.01166 0.0861 0.636 0.2341 ]
+    """
+
+    def __init__(self, dim=None):
+        """Initialize Softmax."""
+        super(SoftmaxExt, self).__init__()
+        self.dim = dim
+
+    def construct(self, input):
+        return ops.function.nn_func.softmax_ext(input, self.dim)
+
+
 class LogSoftmax(Cell):
     r"""
     Applies the LogSoftmax function to n-dimensional input tensor element-wise.
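
Note: the float16 values printed in the SoftmaxExt example above follow directly from the softmax definition exp(x_i) / sum_j exp(x_j). A quick illustrative check in plain NumPy (not MindSpore code):

    import numpy as np

    x = np.array([-1, -2, 0, 2, 1], dtype=np.float32)
    e = np.exp(x - x.max())      # subtract the max for numerical stability
    print(e / e.sum())           # ~[0.0317 0.0117 0.0861 0.6364 0.2341]
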
@@ -329,6 +360,51 @@ class LogSoftmax(Cell):
         return self.log_softmax(x)
 
 
+class LogSoftmaxExt(Cell):
+    r"""
+    Applies the Log Softmax function to the input tensor on the specified axis.
+    Supposes a slice in the given axis, :math:`x` for each element :math:`x_i`,
+    the Log Softmax function is shown as follows:
+
+    .. math::
+        \text{output}(x_i) = \log \left(\frac{\exp(x_i)} {\sum_{j = 0}^{N-1}\exp(x_j)}\right),
+
+    where :math:`N` is the length of the Tensor.
+
+    Args:
+        dim (int, optional): The axis to perform the Log softmax operation. Default: ``None`` .
+
+    Returns:
+        Tensor, with the same shape as the input.
+
+    Raises:
+        ValueError: If `dim` is not in range [-len(input.shape), len(input.shape)).
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]]), mindspore.float32)
+        >>> log_softmax = nn.LogSoftmaxExt(dim=-1)
+        >>> output = log_softmax(x)
+        >>> print(output)
+        [[-5.00672150e+00 -6.72150636e-03 -1.20067215e+01]
+         [-7.00091219e+00 -1.40009127e+01 -9.12250078e-04]]
+    """
+
+    def __init__(self, dim=None):
+        """Initialize LogSoftmaxExt."""
+        super(LogSoftmaxExt, self).__init__()
+        self.log_softmax = P.LogSoftmaxExt()
+        self.dim = dim
+
+    def construct(self, x):
+        return self.log_softmax(x, dim=self.dim)
+
+
 class ELU(Cell):
     r"""
     Applies the exponential linear unit function element-wise.
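
Note: the formula in the LogSoftmaxExt docstring is equivalent to x - logsumexp(x) along `dim`, which is also the numerically stable way to evaluate it. A minimal NumPy sketch (illustration only, not the MindSpore kernel) reproduces the example output above:

    import numpy as np

    def log_softmax(x, dim=-1):
        # log(exp(x_i) / sum_j exp(x_j)) = x_i - logsumexp(x)
        m = x.max(axis=dim, keepdims=True)
        return x - (m + np.log(np.exp(x - m).sum(axis=dim, keepdims=True)))

    x = np.array([[-1.0, 4.0, -8.0], [2.0, -5.0, 9.0]])
    print(log_softmax(x))   # matches the docstring example values
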
@@ -434,8 +510,8 @@ class ReLU(Cell):
         super(ReLU, self).__init__()
         self.relu = P.ReLU()
 
-    def construct(self, x):
-        return self.relu(x)
+    def construct(self, input):
+        return self.relu(input)
 
 
 class ReLU6(Cell):
@@ -898,6 +974,13 @@ class GELU(Cell):
     Outputs:
         Tensor, with the same type and shape as the `x`.
 
+    Note:
+        when calculating the input gradient of GELU with an input value of infinity, there are differences
+        in the output of the backward between ``Ascend`` and ``GPU``.
+        when x is -inf, the computation result of ``Ascend`` is 0, and the computation result of ``GPU`` is Nan.
+        when x is inf, the computation result of ``Ascend`` is dy, and the computation result of ``GPU`` is Nan.
+        In mathematical terms, the result of Ascend has higher precision.
+
     Raises:
         TypeError: If dtype of `x` is not one of float16, float32, or float64.
 
@@ -1164,14 +1247,85 @@ class PReLU(Cell):
         return self.prelu(x, F.cast(self.w, x.dtype))
 
 
+class PReLUExt(Cell):
+    r"""
+    Applies PReLU activation function element-wise.
+
+    PReLU is defined as:
+
+    .. math::
+
+        PReLU(x_i)= \max(0, x_i) + w * \min(0, x_i),
+
+    where :math:`x_i` is an element of an channel of the input.
+
+    Here :math:`w` is a learnable parameter with a default initial value 0.25.
+    Parameter :math:`w` has dimensionality of the argument channel. If called without argument
+    channel, a single parameter :math:`w` will be shared across all channels.
+
+    PReLU Activation Function Graph:
+
+    .. image:: ../images/PReLU2.png
+        :align: center
+
+    .. note::
+        Channel dim is the 2nd dim of input. When input has dims < 2, then there is
+        no channel dim and the number of channels = 1.
+
+    Args:
+        num_parameters (int): number of `w` to learn. Although it takes an int as input,
+            there is only two legitimate values: 1, or the number of channels at Tensor `input`. Default: ``1`` .
+        init (float): the initial value of `w`. Default: ``0.25`` .
+        dtype (mindspore.dtype, optional): the type of `w`. Default: ``None`` . Supported data type
+            is {float16, float32, bfloat16}.
+
+    Inputs:
+        - **input** (Tensor) - The input of PReLU.
+
+    Outputs:
+        Tensor, with the same dtype and shape as the `input`.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor, nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[[[0.1, 0.6], [0.9, 0.9]]]]), mindspore.float32)
+        >>> prelu = nn.PReLUExt()
+        >>> output = prelu(x)
+        >>> print(output)
+        [[[[0.1 0.6]
+           [0.9 0.9]]]]
+
+    """
+
+    def __init__(self, num_parameters=1, init=0.25, dtype=None):
+        """Initialize PReLUExt."""
+        super(PReLUExt, self).__init__()
+        tmp = np.empty((num_parameters,), dtype=np.float32)
+        tmp.fill(init)
+        w = Tensor(tmp, dtype=dtype)
+        self.weight = Parameter(w, name='weight')
+
+    def construct(self, input):
+        return ops.prelu(input, self.weight)
+
+
 class HSwish(Cell):
     r"""
-    Applies hswish-type activation element-wise.
+    Applies Hard Swish activation function element-wise.
 
     Hard swish is defined as:
 
     .. math::
-        \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6},
+        \text{Hardswish}(input) =
+        \begin{cases}
+        0, & \text{ if } input \leq -3, \\
+        input, & \text{ if } input \geq +3, \\
+        input*(input + 3)/6, & \text{ otherwise }
+        \end{cases}
 
     HSwish Activation Function Graph:
 
@@ -1179,14 +1333,14 @@ class HSwish(Cell):
         :align: center
 
     Inputs:
-        - **x** (Tensor) - The input of HSwish, data type must be float16 or float32.
-          The shape is :math:`(N,*)` where :math:`*` means, any number of additional dimensions.
+        - **input** (Tensor) - The input of HSwish.
 
     Outputs:
-        Tensor, with the same type and shape as the `x`.
+        Tensor, with the same type and shape as the `input`.
 
     Raises:
-        TypeError: If dtype of `x` is neither float16 nor float32.
+        TypeError: If `input` is not a tensor.
+        TypeError: If `input` is neither int nor float.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1195,9 +1349,9 @@ class HSwish(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
         >>> hswish = nn.HSwish()
-        >>> result = hswish(x)
+        >>> result = hswish(input)
         >>> print(result)
         [-0.3333 -0.3333 0. 1.667 0.6665]
     """
@@ -1207,18 +1361,23 @@ class HSwish(Cell):
         super(HSwish, self).__init__()
         self.hswish = P.HSwish()
 
-    def construct(self, x):
-        return self.hswish(x)
+    def construct(self, input):
+        return self.hswish(input)
 
 
 class HSigmoid(Cell):
     r"""
-    Applies Hard sigmoid activation function element-wise.
+    Applies Hard Sigmoid activation function element-wise.
 
-    Hard sigmoid is defined as:
+    Hard Sigmoid is defined as:
 
     .. math::
-        \text{hsigmoid}(x_{i}) = \max(0, \min(1, \frac{x_{i} + 3}{6})),
+        \text{Hardsigmoid}(input) =
+        \begin{cases}
+        0, & \text{ if } input \leq -3, \\
+        1, & \text{ if } input \geq +3, \\
+        input/6 + 1/2, & \text{ otherwise }
+        \end{cases}
 
     HSigmoid Activation Function Graph:
 
@@ -1226,13 +1385,14 @@ class HSigmoid(Cell):
         :align: center
 
     Inputs:
-        - **input_x** (Tensor) - The input of HSigmoid. Tensor of any dimension.
+        - **input** (Tensor) - The input of HSigmoid.
 
     Outputs:
-        Tensor, with the same type and shape as the `input_x`.
+        Tensor, with the same type and shape as the `input`.
 
     Raises:
-        TypeError: If `input_x` is not a Tensor.
+        TypeError: If `input` is not a Tensor.
+        TypeError: If `input` is neither int nor float.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1241,9 +1401,9 @@ class HSigmoid(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>> x = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
+        >>> input = Tensor(np.array([-1, -2, 0, 2, 1]), mindspore.float16)
        >>> hsigmoid = nn.HSigmoid()
-        >>> result = hsigmoid(x)
+        >>> result = hsigmoid(input)
         >>> print(result)
         [0.3333 0.1666 0.5 0.8335 0.6665]
     """
@@ -1253,8 +1413,8 @@ class HSigmoid(Cell):
         super(HSigmoid, self).__init__()
         self.hsigmoid = P.HSigmoid()
 
-    def construct(self, input_x):
-        return self.hsigmoid(input_x)
+    def construct(self, input):
+        return self.hsigmoid(input)
 
 
 class LogSigmoid(Cell):
@@ -1370,21 +1530,22 @@ class SoftShrink(Cell):
         :align: center
 
     Args:
-        lambd (float): the :math:`\lambda` must be no less than zero for the SoftShrink formulation.
-            Default: ``0.5`` .
+        lambd (number, optional): The threshold :math:`\lambda` defined by the Soft Shrink formula.
+            It should be greater than or equal to 0, default: ``0.5`` .
 
     Inputs:
-        - **input_x** (Tensor) - The input of SoftShrink with data type of float16 or float32.
-          Any number of additional dimensions.
+        - **input** (Tensor) - The input of Soft Shrink. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32.
 
     Outputs:
-        Tensor, has the same shape and data type as `input_x`.
+        Tensor, the same shape and data type as the input.
 
     Raises:
-        TypeError: If lambd is not a float.
-        TypeError: If input_x is not a Tensor.
-        TypeError: If dtype of input_x is neither float16 nor float32.
-        ValueError: If lambd is less than 0.
+        TypeError: If `lambd` is not a float, int or bool.
+        TypeError: If `input` is not a tensor.
+        TypeError: If dtype of `input` is not float16, float32 or bfloat16.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1393,9 +1554,9 @@ class SoftShrink(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>> input_x = Tensor(np.array([[ 0.5297, 0.7871, 1.1754], [ 0.7836, 0.6218, -1.1542]]), mindspore.float16)
+        >>> input = Tensor(np.array([[ 0.5297, 0.7871, 1.1754], [ 0.7836, 0.6218, -1.1542]]), mindspore.float16)
         >>> softshrink = nn.SoftShrink()
-        >>> output = softshrink(input_x)
+        >>> output = softshrink(input)
         >>> print(output)
         [[ 0.02979 0.287 0.676 ]
          [ 0.2837 0.1216 -0.6543 ]]
@@ -1405,8 +1566,8 @@ class SoftShrink(Cell):
         super(SoftShrink, self).__init__()
         self.softshrink = P.SoftShrink(lambd)
 
-    def construct(self, input_x):
-        output = self.softshrink(input_x)
+    def construct(self, input):
+        output = self.softshrink(input)
         return output
 
 
@@ -1430,17 +1591,21 @@ class HShrink(Cell):
         :align: center
 
     Args:
-        lambd (float): The threshold :math:`\lambda` defined by the Hard Shrink formula. Default: ``0.5`` .
+        lambd (number, optional): The threshold :math:`\lambda` defined by the Hard Shrink formula. Default: ``0.5`` .
 
     Inputs:
-        - **input_x** (Tensor) - The input of Hard Shrink with data type of float16 or float32.
+        - **input** (Tensor) - The input of Hard Shrink. Supported dtypes:
+
+          - Ascend: float16, float32, bfloat16.
+          - CPU/GPU: float16, float32.
 
     Outputs:
         Tensor, the same shape and data type as the input.
 
     Raises:
-        TypeError: If `lambd` is not a float.
-        TypeError: If dtype of `input_x` is neither float16 nor float32.
+        TypeError: If `lambd` is not a float, int or bool.
+        TypeError: If `input` is not a tensor.
+        TypeError: If dtype of `input` is not float16, float32 or bfloat16.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -1449,20 +1614,20 @@ class HShrink(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>> input_x = Tensor(np.array([[ 0.5, 1, 2.0], [0.0533,0.0776,-2.1233]]), mindspore.float32)
+        >>> input = Tensor(np.array([[0.5, 1, 2.0], [0.0533, 0.0776, -2.1233]]), mindspore.float32)
         >>> hshrink = nn.HShrink()
-        >>> output = hshrink(input_x)
+        >>> output = hshrink(input)
         >>> print(output)
         [[ 0. 1. 2. ]
-        [ 0. 0. -2.1233]]
+         [ 0. 0. -2.1233]]
     """
 
     def __init__(self, lambd=0.5):
         super(HShrink, self).__init__()
         self.hshrink = P.HShrink(lambd)
 
-    def construct(self, input_x):
-        return self.hshrink(input_x)
+    def construct(self, input):
+        return self.hshrink(input)
 
 
 class Threshold(Cell):
@@ -1602,6 +1767,7 @@ _activation = {
     'softmax': Softmax,
     'softmax2d': Softmax2d,
     'logsoftmax': LogSoftmax,
+    'logsoftmaxExt': LogSoftmaxExt,
     'relu': ReLU,
     'relu6': ReLU6,
     'rrelu': RReLU,
@@ -1615,6 +1781,7 @@ _activation = {
     'sigmoid': Sigmoid,
     'softsign': Softsign,
     'prelu': PReLU,
+    'preluExt': PReLUExt,
     'leakyrelu': LeakyReLU,
     'hswish': HSwish,
     'hsigmoid': HSigmoid,
mindspore/nn/layer/basic.py:

@@ -40,7 +40,7 @@ from mindspore.common._decorator import deprecated
 from mindspore.ops.auto_generate import dropout_ext_op, fold_ext
 from mindspore.common.generator import default_generator
 
-__all__ = ['Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
+__all__ = ['Dropout', 'Flatten', 'Dense', 'Linear', 'ClipByNorm', 'Norm', 'OneHot', 'Pad', 'Unfold', 'Tril', 'Triu',
            'MatrixDiag', 'MatrixDiagPart', 'MatrixSetDiag', 'L1Regularizer', 'Dropout1d',
            'Dropout2d', 'Dropout3d', 'Upsample', 'Roll', 'Identity', 'Unflatten', 'DropoutExt']
 
@@ -510,8 +510,8 @@ class UpsampleExt(Cell):
         self.align_corners = align_corners
         self.recompute_scale_factor = recompute_scale_factor
 
-    def construct(self, x):
-        out = interpolate_ext(x, self.size, self.scale_factor, self.mode,
+    def construct(self, input):
+        out = interpolate_ext(input, self.size, self.scale_factor, self.mode,
                               self.align_corners, self.recompute_scale_factor)
         return out
 
@@ -579,11 +579,15 @@ class Identity(Cell):
     r"""
     A placeholder identity operator that returns the same as input.
 
+    Args:
+        args (Any): Any argument.
+        kwargs (Any): Any keyword argument.
+
     Inputs:
-        - **x** (Any) - The input of Identity.
+        - **input** (Any) - The input of Identity.
 
     Outputs:
-        The same as `x`.
+        The same as `input`.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -592,19 +596,19 @@ class Identity(Cell):
         >>> import mindspore
         >>> from mindspore import Tensor, nn
         >>> import numpy as np
-        >>> x = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
+        >>> input = Tensor(np.array([1, 2, 3, 4]), mindspore.int64)
         >>> net = nn.Identity()
-        >>> output = net(x)
+        >>> output = net(input)
         >>> print(output)
         [1 2 3 4]
     """
 
-    def __init__(self):
+    def __init__(self, *args, **kwargs):
         """Initialize Identity."""
         super(Identity, self).__init__()
 
-    def construct(self, x):
-        return x
+    def construct(self, input):
+        return input
 
 
 class Dense(Cell):
@@ -621,6 +625,9 @@ class Dense(Cell):
     data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
     with the same data type as the :math:`X` created by the layer (only if has_bias is True).
 
+    .. warning::
+        In PYNATIVE mode, if `bias` is ``False`` , the `x` cannot be greater than 6D.
+
     Args:
         in_channels (int): The number of channels in the input space.
         out_channels (int): The number of channels in the output space.
@@ -635,6 +642,8 @@
             layer. Both activation name, e.g. 'relu', and mindspore activation function, e.g. mindspore.ops.ReLU(),
             are supported. Default: ``None`` .
         dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``mstype.float32`` .
+            When `weight_init` is Tensor, Parameter has the same data type as `weight_init` ,
+            in other cases, Parameter has the same data type as `dtype`, the same goes for `bias_init`.
 
     Inputs:
         - **x** (Tensor) - Tensor of shape :math:`(*, in\_channels)`. The `in_channels` in `Args` should be equal
@@ -651,6 +660,7 @@
             is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
         ValueError: If length of shape of `bias_init` is not equal to 1
             or shape[0] of `bias_init` is not equal to `out_channels`.
+        RuntimeError: If `bias` is ``False`` and `x` is greater than 6D in PYNATIVE mode.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -743,6 +753,123 @@ class Dense(Cell):
         return s
 
 
+class Linear(Cell):
+    r"""
+    The linear connected layer.
+
+    Applies linear connected layer for the input. This layer implements the operation as:
+
+    .. math::
+        \text{outputs} = X * kernel + bias
+
+    .. warning::
+        In PYNATIVE mode, if `bias` is ``False`` , the `x` cannot be greater than 6D.
+
+    where :math:`X` is the input tensors, :math:`\text{kernel}` is a weight matrix with the same
+    data type as the :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
+    with the same data type as the :math:`X` created by the layer (only if has_bias is True).
+
+    Args:
+        in_features (int): The number of features in the input space.
+        out_features (int): The number of features in the output space.
+        bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            weight will be initialized using HeUniform.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as `x`. The values of str refer to the function `initializer`. Default: ``None`` ,
+            bias will be initialized using Uniform.
+        dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None`` .
+            If `dtype` is ``None`` , `dtype` is set to ``mstype.float32`` when initializing the method.
+            When `weight_init` is Tensor, Parameter has the same data type as `weight_init` ,
+            in other cases, Parameter has the same data type as `dtype`, the same goes for `bias_init`.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
+          to :math:`in\_features` in `Inputs`.
+
+    Outputs:
+        Tensor of shape :math:`(*, out\_features)`.
+
+    Raises:
+        TypeError: If `in_features` or `out_features` is not an int.
+        TypeError: If `bias` is not a bool.
+        ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
+            is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
+        ValueError: If length of shape of `bias_init` is not equal to 1
+            or shape[0] of `bias_init` is not equal to `out_features`.
+        RuntimeError: If `bias` is ``False`` and `x` is greater than 6D in PYNATIVE mode.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> import mindspore
+        >>> from mindspore import Tensor
+        >>> from mindspore import nn
+        >>> import numpy as np
+        >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = nn.mint.nn.Linear(3, 4)
+        >>> output = net(x)
+        >>> print(output.shape)
+        (2, 4)
+    """
+
+    @cell_attr_register(attrs=['has_bias'])
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 weight_init=None,
+                 bias_init=None,
+                 dtype=None):
+        """Initialize Linear."""
+        super(Linear, self).__init__()
+        self.in_features = Validator.check_positive_int(
+            in_features, "in_features", self.cls_name)
+        self.out_features = Validator.check_positive_int(
+            out_features, "out_features", self.cls_name)
+        self.has_bias = Validator.check_bool(
+            bias, "has_bias", self.cls_name)
+        self.dense = P.Dense()
+        if dtype is None:
+            dtype = mstype.float32
+        if isinstance(weight_init, Tensor):
+            if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
+                    weight_init.shape[1] != in_features:
+                raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
+                                 f"be equal to 2, and the first dim must be equal to 'out_features', and the "
+                                 f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
+                                 f"'out_features': {out_features}, 'in_features': {in_features}.")
+        if weight_init is None:
+            weight_init = HeUniform(math.sqrt(5))
+        self.weight = Parameter(initializer(
+            weight_init, [out_features, in_features], dtype=dtype), name="weight")
+
+        self.bias = None
+        if self.has_bias:
+            if isinstance(bias_init, Tensor):
+                if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
+                    raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
+                                     f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
+                                     f"'bias_init': {bias_init}, 'out_features': {out_features}.")
+            if bias_init is None:
+                bound = 1 / math.sqrt(in_features)
+                bias_init = Uniform(scale=bound)
+            self.bias = Parameter(initializer(
+                bias_init, [out_features], dtype=dtype), name="bias")
+
+    def construct(self, x):
+        x = self.dense(x, self.weight, self.bias)
+        return x
+
+    def extend_repr(self):
+        s = f'input_features={self.in_features}, output_features={self.out_features}'
+        if self.has_bias:
+            s += f', has_bias={self.has_bias}'
+        return s
+
+
 @constexpr
 def _is_equal_one(x):
     if x is None:
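
Note: the forward pass of the new Linear layer (outputs = X * kernel + bias, with the weight stored as [out_features, in_features]) amounts to the following NumPy sketch; this illustrates the math only and is not the MindSpore implementation, and the uniform weight range here is a placeholder rather than the HeUniform/Uniform initializers used above:

    import numpy as np

    in_features, out_features = 3, 4
    x = np.array([[180, 234, 154], [244, 48, 247]], dtype=np.float32)
    weight = np.random.uniform(-0.1, 0.1, (out_features, in_features)).astype(np.float32)  # placeholder init
    bias = np.zeros(out_features, dtype=np.float32)
    y = x @ weight.T + bias
    print(y.shape)   # (2, 4), matching the docstring example
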