mindspore-2.3.0-cp310-cp310-win_amd64.whl → mindspore-2.4.0-cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic.

Files changed (308)
  1. mindspore/.commit_id +1 -1
  2. mindspore/Microsoft.VisualStudio.Telemetry.dll +0 -0
  3. mindspore/Newtonsoft.Json.dll +0 -0
  4. mindspore/__init__.py +3 -1
  5. mindspore/_c_dataengine.cp310-win_amd64.pyd +0 -0
  6. mindspore/_c_expression.cp310-win_amd64.pyd +0 -0
  7. mindspore/_c_mindrecord.cp310-win_amd64.pyd +0 -0
  8. mindspore/_checkparam.py +50 -9
  9. mindspore/_extends/parse/compile_config.py +41 -0
  10. mindspore/_extends/parse/parser.py +9 -7
  11. mindspore/_extends/parse/standard_method.py +52 -14
  12. mindspore/_extends/pijit/pijit_func_white_list.py +350 -24
  13. mindspore/amp.py +24 -10
  14. mindspore/atlprov.dll +0 -0
  15. mindspore/avcodec-59.dll +0 -0
  16. mindspore/avdevice-59.dll +0 -0
  17. mindspore/avfilter-8.dll +0 -0
  18. mindspore/avformat-59.dll +0 -0
  19. mindspore/avutil-57.dll +0 -0
  20. mindspore/c1.dll +0 -0
  21. mindspore/c1xx.dll +0 -0
  22. mindspore/c2.dll +0 -0
  23. mindspore/common/__init__.py +6 -4
  24. mindspore/common/_pijit_context.py +190 -0
  25. mindspore/common/_register_for_tensor.py +2 -1
  26. mindspore/common/_tensor_overload.py +139 -0
  27. mindspore/common/api.py +102 -87
  28. mindspore/common/dump.py +5 -6
  29. mindspore/common/generator.py +1 -7
  30. mindspore/common/hook_handle.py +14 -26
  31. mindspore/common/mindir_util.py +2 -2
  32. mindspore/common/parameter.py +46 -13
  33. mindspore/common/recompute.py +39 -9
  34. mindspore/common/sparse_tensor.py +7 -3
  35. mindspore/common/tensor.py +209 -29
  36. mindspore/communication/__init__.py +1 -1
  37. mindspore/communication/_comm_helper.py +38 -3
  38. mindspore/communication/comm_func.py +310 -55
  39. mindspore/communication/management.py +14 -14
  40. mindspore/context.py +123 -22
  41. mindspore/dataset/__init__.py +1 -1
  42. mindspore/dataset/audio/__init__.py +1 -1
  43. mindspore/dataset/core/config.py +7 -0
  44. mindspore/dataset/core/validator_helpers.py +7 -0
  45. mindspore/dataset/engine/cache_client.py +1 -1
  46. mindspore/dataset/engine/datasets.py +72 -44
  47. mindspore/dataset/engine/datasets_audio.py +7 -7
  48. mindspore/dataset/engine/datasets_standard_format.py +53 -3
  49. mindspore/dataset/engine/datasets_text.py +20 -20
  50. mindspore/dataset/engine/datasets_user_defined.py +174 -104
  51. mindspore/dataset/engine/datasets_vision.py +33 -33
  52. mindspore/dataset/engine/iterators.py +29 -0
  53. mindspore/dataset/engine/obs/util.py +7 -0
  54. mindspore/dataset/engine/queue.py +114 -60
  55. mindspore/dataset/engine/serializer_deserializer.py +2 -2
  56. mindspore/dataset/engine/validators.py +34 -14
  57. mindspore/dataset/text/__init__.py +1 -4
  58. mindspore/dataset/transforms/__init__.py +0 -3
  59. mindspore/dataset/utils/line_reader.py +2 -0
  60. mindspore/dataset/vision/__init__.py +1 -4
  61. mindspore/dataset/vision/utils.py +1 -1
  62. mindspore/dataset/vision/validators.py +2 -1
  63. mindspore/dnnl.dll +0 -0
  64. mindspore/dpcmi.dll +0 -0
  65. mindspore/{nn/extend → experimental/es}/__init__.py +4 -11
  66. mindspore/experimental/es/embedding_service.py +883 -0
  67. mindspore/{nn/layer → experimental/es}/embedding_service_layer.py +218 -30
  68. mindspore/experimental/llm_boost/__init__.py +21 -0
  69. mindspore/{nn/extend/layer → experimental/llm_boost/atb}/__init__.py +4 -8
  70. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  71. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  72. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  73. mindspore/experimental/llm_boost/register.py +129 -0
  74. mindspore/experimental/llm_boost/utils.py +31 -0
  75. mindspore/experimental/optim/adamw.py +85 -0
  76. mindspore/experimental/optim/optimizer.py +3 -0
  77. mindspore/hal/__init__.py +3 -3
  78. mindspore/hal/contiguous_tensors_handle.py +175 -0
  79. mindspore/hal/stream.py +18 -0
  80. mindspore/include/api/model_group.h +13 -1
  81. mindspore/include/api/types.h +10 -10
  82. mindspore/include/dataset/config.h +2 -2
  83. mindspore/include/dataset/constants.h +2 -2
  84. mindspore/include/dataset/execute.h +2 -2
  85. mindspore/include/dataset/vision.h +4 -0
  86. mindspore/jpeg62.dll +0 -0
  87. mindspore/log.py +1 -1
  88. mindspore/mindrecord/filewriter.py +68 -51
  89. mindspore/mindspore_backend.dll +0 -0
  90. mindspore/mindspore_common.dll +0 -0
  91. mindspore/mindspore_core.dll +0 -0
  92. mindspore/mindspore_glog.dll +0 -0
  93. mindspore/mindspore_np_dtype.dll +0 -0
  94. mindspore/mindspore_ops.dll +0 -0
  95. mindspore/mint/__init__.py +495 -46
  96. mindspore/mint/distributed/__init__.py +31 -0
  97. mindspore/mint/distributed/distributed.py +254 -0
  98. mindspore/mint/nn/__init__.py +266 -21
  99. mindspore/mint/nn/functional.py +125 -19
  100. mindspore/mint/nn/layer/__init__.py +39 -0
  101. mindspore/mint/nn/layer/activation.py +133 -0
  102. mindspore/mint/nn/layer/normalization.py +477 -0
  103. mindspore/mint/nn/layer/pooling.py +110 -0
  104. mindspore/mint/optim/adamw.py +28 -7
  105. mindspore/mint/special/__init__.py +63 -0
  106. mindspore/msobj140.dll +0 -0
  107. mindspore/mspdb140.dll +0 -0
  108. mindspore/mspdbcore.dll +0 -0
  109. mindspore/mspdbst.dll +0 -0
  110. mindspore/mspft140.dll +0 -0
  111. mindspore/msvcdis140.dll +0 -0
  112. mindspore/msvcp140_1.dll +0 -0
  113. mindspore/msvcp140_2.dll +0 -0
  114. mindspore/msvcp140_atomic_wait.dll +0 -0
  115. mindspore/msvcp140_codecvt_ids.dll +0 -0
  116. mindspore/multiprocessing/__init__.py +2 -1
  117. mindspore/nn/__init__.py +0 -1
  118. mindspore/nn/cell.py +275 -93
  119. mindspore/nn/layer/activation.py +211 -44
  120. mindspore/nn/layer/basic.py +113 -3
  121. mindspore/nn/layer/embedding.py +120 -2
  122. mindspore/nn/layer/normalization.py +101 -5
  123. mindspore/nn/layer/padding.py +34 -48
  124. mindspore/nn/layer/pooling.py +161 -7
  125. mindspore/nn/layer/transformer.py +3 -3
  126. mindspore/nn/loss/__init__.py +2 -2
  127. mindspore/nn/loss/loss.py +84 -6
  128. mindspore/nn/optim/__init__.py +2 -1
  129. mindspore/nn/optim/adadelta.py +1 -1
  130. mindspore/nn/optim/adam.py +1 -1
  131. mindspore/nn/optim/lamb.py +1 -1
  132. mindspore/nn/optim/tft_wrapper.py +127 -0
  133. mindspore/nn/wrap/cell_wrapper.py +12 -23
  134. mindspore/nn/wrap/grad_reducer.py +5 -5
  135. mindspore/nn/wrap/loss_scale.py +17 -3
  136. mindspore/numpy/__init__.py +1 -1
  137. mindspore/numpy/array_creations.py +65 -68
  138. mindspore/numpy/array_ops.py +64 -60
  139. mindspore/numpy/fft.py +610 -75
  140. mindspore/numpy/logic_ops.py +11 -10
  141. mindspore/numpy/math_ops.py +85 -84
  142. mindspore/numpy/utils_const.py +4 -4
  143. mindspore/opencv_core452.dll +0 -0
  144. mindspore/opencv_imgcodecs452.dll +0 -0
  145. mindspore/opencv_imgproc452.dll +0 -0
  146. mindspore/ops/__init__.py +6 -4
  147. mindspore/ops/_grad_experimental/grad_comm_ops.py +47 -3
  148. mindspore/ops/_grad_experimental/grad_math_ops.py +0 -22
  149. mindspore/ops/_vmap/vmap_array_ops.py +2 -4
  150. mindspore/ops/_vmap/vmap_math_ops.py +17 -1
  151. mindspore/ops/_vmap/vmap_nn_ops.py +43 -2
  152. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +85 -7
  153. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +2 -0
  154. mindspore/ops/auto_generate/gen_extend_func.py +734 -13
  155. mindspore/ops/auto_generate/gen_ops_def.py +2420 -381
  156. mindspore/ops/auto_generate/gen_ops_prim.py +5196 -1659
  157. mindspore/ops/auto_generate/pyboost_inner_prim.py +176 -56
  158. mindspore/ops/composite/base.py +85 -48
  159. mindspore/ops/composite/multitype_ops/_compile_utils.py +1 -0
  160. mindspore/ops/composite/multitype_ops/not_in_impl.py +2 -2
  161. mindspore/ops/function/__init__.py +22 -0
  162. mindspore/ops/function/array_func.py +490 -153
  163. mindspore/ops/function/debug_func.py +113 -1
  164. mindspore/ops/function/fft_func.py +15 -2
  165. mindspore/ops/function/grad/grad_func.py +3 -2
  166. mindspore/ops/function/math_func.py +558 -207
  167. mindspore/ops/function/nn_func.py +817 -383
  168. mindspore/ops/function/other_func.py +3 -2
  169. mindspore/ops/function/random_func.py +184 -8
  170. mindspore/ops/function/reshard_func.py +13 -11
  171. mindspore/ops/function/sparse_unary_func.py +1 -1
  172. mindspore/ops/function/vmap_func.py +3 -2
  173. mindspore/ops/functional.py +24 -14
  174. mindspore/ops/op_info_register.py +3 -3
  175. mindspore/ops/operations/__init__.py +6 -1
  176. mindspore/ops/operations/_grad_ops.py +2 -76
  177. mindspore/ops/operations/_infer_ops.py +1 -1
  178. mindspore/ops/operations/_inner_ops.py +71 -94
  179. mindspore/ops/operations/array_ops.py +12 -146
  180. mindspore/ops/operations/comm_ops.py +42 -53
  181. mindspore/ops/operations/custom_ops.py +83 -19
  182. mindspore/ops/operations/debug_ops.py +42 -10
  183. mindspore/ops/operations/manually_defined/_inner.py +12 -0
  184. mindspore/ops/operations/manually_defined/ops_def.py +265 -10
  185. mindspore/ops/operations/math_ops.py +12 -223
  186. mindspore/ops/operations/nn_ops.py +20 -114
  187. mindspore/ops/operations/other_ops.py +7 -4
  188. mindspore/ops/operations/random_ops.py +46 -1
  189. mindspore/ops/primitive.py +18 -6
  190. mindspore/ops_generate/arg_dtype_cast.py +2 -0
  191. mindspore/ops_generate/gen_aclnn_implement.py +11 -11
  192. mindspore/ops_generate/gen_constants.py +36 -0
  193. mindspore/ops_generate/gen_ops.py +67 -52
  194. mindspore/ops_generate/gen_ops_inner_prim.py +1 -1
  195. mindspore/ops_generate/gen_pyboost_func.py +131 -47
  196. mindspore/ops_generate/op_proto.py +10 -3
  197. mindspore/ops_generate/pyboost_utils.py +14 -1
  198. mindspore/ops_generate/template.py +43 -21
  199. mindspore/parallel/__init__.py +3 -1
  200. mindspore/parallel/_auto_parallel_context.py +28 -8
  201. mindspore/parallel/_cell_wrapper.py +83 -0
  202. mindspore/parallel/_parallel_serialization.py +47 -19
  203. mindspore/parallel/_tensor.py +81 -11
  204. mindspore/parallel/_utils.py +13 -1
  205. mindspore/parallel/algo_parameter_config.py +5 -5
  206. mindspore/parallel/checkpoint_transform.py +46 -39
  207. mindspore/parallel/cluster/process_entity/__init__.py +1 -1
  208. mindspore/parallel/cluster/process_entity/_api.py +31 -23
  209. mindspore/parallel/cluster/process_entity/_utils.py +2 -27
  210. mindspore/parallel/parameter_broadcast.py +3 -4
  211. mindspore/parallel/shard.py +162 -31
  212. mindspore/parallel/transform_safetensors.py +993 -0
  213. mindspore/pgodb140.dll +0 -0
  214. mindspore/pgort140.dll +0 -0
  215. mindspore/profiler/__init__.py +2 -1
  216. mindspore/profiler/common/constant.py +29 -0
  217. mindspore/profiler/common/registry.py +47 -0
  218. mindspore/profiler/common/util.py +28 -0
  219. mindspore/profiler/dynamic_profiler.py +694 -0
  220. mindspore/profiler/envprofiling.py +17 -19
  221. mindspore/profiler/parser/ascend_analysis/constant.py +18 -0
  222. mindspore/profiler/parser/ascend_analysis/file_manager.py +25 -4
  223. mindspore/profiler/parser/ascend_analysis/function_event.py +43 -19
  224. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +31 -26
  225. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +56 -10
  226. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +55 -8
  227. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  228. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +27 -20
  229. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +9 -2
  230. mindspore/profiler/parser/ascend_msprof_exporter.py +5 -4
  231. mindspore/profiler/parser/ascend_timeline_generator.py +27 -25
  232. mindspore/profiler/parser/base_timeline_generator.py +19 -25
  233. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +25 -12
  234. mindspore/profiler/parser/framework_parser.py +1 -391
  235. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  236. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  237. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  238. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  239. mindspore/profiler/parser/memory_usage_parser.py +0 -154
  240. mindspore/profiler/parser/profiler_info.py +78 -6
  241. mindspore/profiler/profiler.py +153 -0
  242. mindspore/profiler/profiling.py +280 -412
  243. mindspore/rewrite/__init__.py +1 -2
  244. mindspore/rewrite/common/namespace.py +4 -4
  245. mindspore/rewrite/symbol_tree/symbol_tree.py +3 -3
  246. mindspore/run_check/_check_version.py +36 -103
  247. mindspore/safeguard/rewrite_obfuscation.py +591 -247
  248. mindspore/swresample-4.dll +0 -0
  249. mindspore/swscale-6.dll +0 -0
  250. mindspore/tbbmalloc.dll +0 -0
  251. mindspore/tinyxml2.dll +0 -0
  252. mindspore/train/__init__.py +4 -3
  253. mindspore/train/_utils.py +28 -2
  254. mindspore/train/amp.py +171 -53
  255. mindspore/train/callback/__init__.py +2 -2
  256. mindspore/train/callback/_callback.py +4 -4
  257. mindspore/train/callback/_checkpoint.py +85 -22
  258. mindspore/train/callback/_cluster_monitor.py +1 -1
  259. mindspore/train/callback/_flops_collector.py +1 -0
  260. mindspore/train/callback/_loss_monitor.py +3 -3
  261. mindspore/train/callback/_on_request_exit.py +134 -31
  262. mindspore/train/callback/_summary_collector.py +5 -5
  263. mindspore/train/callback/_tft_register.py +352 -0
  264. mindspore/train/dataset_helper.py +7 -3
  265. mindspore/train/metrics/metric.py +3 -3
  266. mindspore/train/metrics/roc.py +4 -4
  267. mindspore/train/mind_ir_pb2.py +44 -39
  268. mindspore/train/model.py +134 -58
  269. mindspore/train/serialization.py +336 -112
  270. mindspore/turbojpeg.dll +0 -0
  271. mindspore/utils/__init__.py +21 -0
  272. mindspore/utils/utils.py +60 -0
  273. mindspore/vcmeta.dll +0 -0
  274. mindspore/vcruntime140.dll +0 -0
  275. mindspore/vcruntime140_1.dll +0 -0
  276. mindspore/version.py +1 -1
  277. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/METADATA +6 -2
  278. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/RECORD +281 -275
  279. mindspore/include/c_api/ms/abstract.h +0 -67
  280. mindspore/include/c_api/ms/attribute.h +0 -197
  281. mindspore/include/c_api/ms/base/handle_types.h +0 -43
  282. mindspore/include/c_api/ms/base/macros.h +0 -32
  283. mindspore/include/c_api/ms/base/status.h +0 -33
  284. mindspore/include/c_api/ms/base/types.h +0 -283
  285. mindspore/include/c_api/ms/context.h +0 -102
  286. mindspore/include/c_api/ms/graph.h +0 -160
  287. mindspore/include/c_api/ms/node.h +0 -606
  288. mindspore/include/c_api/ms/tensor.h +0 -161
  289. mindspore/include/c_api/ms/value.h +0 -84
  290. mindspore/mindspore_shared_lib.dll +0 -0
  291. mindspore/nn/extend/basic.py +0 -140
  292. mindspore/nn/extend/embedding.py +0 -143
  293. mindspore/nn/extend/layer/normalization.py +0 -109
  294. mindspore/nn/extend/pooling.py +0 -117
  295. mindspore/nn/layer/embedding_service.py +0 -531
  296. mindspore/ops/_op_impl/aicpu/strided_slice_v2.py +0 -93
  297. mindspore/ops/_op_impl/aicpu/strided_slice_v2_grad.py +0 -66
  298. mindspore/ops/extend/__init__.py +0 -53
  299. mindspore/ops/extend/array_func.py +0 -218
  300. mindspore/ops/extend/math_func.py +0 -76
  301. mindspore/ops/extend/nn_func.py +0 -308
  302. mindspore/ops/silent_check.py +0 -162
  303. mindspore/profiler/parser/msadvisor_analyzer.py +0 -82
  304. mindspore/profiler/parser/msadvisor_parser.py +0 -240
  305. mindspore/train/callback/_mindio_ttp.py +0 -443
  306. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/WHEEL +0 -0
  307. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/entry_points.txt +0 -0
  308. {mindspore-2.3.0.dist-info → mindspore-2.4.0.dist-info}/top_level.txt +0 -0
mindspore/include/c_api/ms/tensor.h
@@ -1,161 +0,0 @@
- /**
- * Copyright 2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
- #define MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
-
- #include <stdbool.h>
- #include <stdlib.h>
- #include "include/c_api/ms/base/macros.h"
- #include "include/c_api/ms/base/status.h"
- #include "include/c_api/ms/base/types.h"
- #include "include/c_api/ms/base/handle_types.h"
- #include "include/c_api/ms/context.h"
-
- #ifdef __cplusplus
- extern "C" {
- #endif
-
- /// \brief Create a tensor with input data buffer.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] data The input data to be copied into the tensor.
- /// \param[in] type [TypeId] Data type of the tensor.
- /// \param[in] shape The shape array of the tensor.
- /// \param[in] shape_size The size of the shape array, i.e., the rank of the tensor.
- /// \param[in] data_len The length of data in bytes.
- ///
- /// \return The pointer of the created tensor instance.
- MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, DataTypeC type, const int64_t shape[],
-                                     size_t shape_size, size_t data_len);
-
- /// \brief Create a tensor with a path to a space-separated txt file.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] type [TypeId] Data type of the tensor.
- /// \param[in] shape The shape array of the tensor.
- /// \param[in] shape_size The size of the shape array, i.e., the rank of the tensor.
- /// \param[in] path Path to the file.
- ///
- /// \return The pointer of the created tensor instance.
- MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, DataTypeC type, const int64_t shape[],
-                                             size_t shape_size, const char *path);
-
- /// \brief Create a tensor with input data buffer and given source data type.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] shape The shape array of the tensor.
- /// \param[in] shape_size The size of the shape array, i.e., the rank of the tensor.
- /// \param[in] data The input data to be copied into the tensor.
- /// \param[in] tensor_type [TypeId] Data type of the tensor.
- /// \param[in] src_type [TypeId] The source data type.
- ///
- /// \return The pointer of the created tensor instance.
- MIND_C_API TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int64_t shape[],
-                                                size_t shape_size, DataTypeC tensor_type, DataTypeC src_type);
-
- /// \brief Create a tensor with a float32 scalar value.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] value The input scalar value.
- ///
- /// \return The pointer of the created tensor instance.
- MIND_C_API TensorHandle MSNewTensorScalarFloat32(ResMgrHandle res_mgr, float value);
-
- /// \brief Create a tensor with an int32 scalar value.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] value The input scalar value.
- ///
- /// \return The pointer of the created tensor instance.
- MIND_C_API TensorHandle MSNewTensorScalarInt32(ResMgrHandle res_mgr, int value);
-
- /// \brief Get the raw pointer of tensor data.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- ///
- /// \return The pointer to the tensor data.
- MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor);
-
- /// \brief Set tensor data type.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] type The data type to be set.
- ///
- /// \return Error code that indicates whether the function executed successfully.
- MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, DataTypeC type);
-
- /// \brief Get tensor data type.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- ///
- /// \return The data type of tensor.
- MIND_C_API DataTypeC MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
-
- /// \brief Get the byte size of tensor data.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] error Records an error code that indicates whether the function executed successfully.
- ///
- /// \return The byte size of tensor data.
- MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
-
- /// \brief Get the element number of tensor array.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] error Records an error code that indicates whether the function executed successfully.
- ///
- /// \return The element number of tensor array.
- MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
-
- /// \brief Get the dimension of tensor.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] error Records an error code that indicates whether the function executed successfully.
- ///
- /// \return The dimension of tensor.
- MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
-
- /// \brief Set the shape of tensor array.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] shape The shape array.
- /// \param[in] dim The dimension of tensor, i.e., the size of the shape array.
- ///
- /// \return Error code that indicates whether the function executed successfully.
- MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, TensorHandle tensor, const int64_t shape[], size_t dim);
-
- /// \brief Get the shape of tensor array.
- ///
- /// \param[in] res_mgr Resource manager that saves allocated instance resources.
- /// \param[in] tensor The pointer of the tensor instance.
- /// \param[in] shape The shape array.
- /// \param[in] dim The dimension of tensor, i.e., the size of the shape array.
- ///
- /// \return Error code that indicates whether the function executed successfully.
- MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tensor, int64_t shape[], size_t dim);
-
- #ifdef __cplusplus
- }
- #endif
- #endif  // MINDSPORE_CCSRC_C_API_INCLUDE_FUNC_TENSOR_H_
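
Editorial note: the tensor constructors and accessors above were removed along with the rest of this experimental C API. Their everyday equivalents live on `mindspore.Tensor` in Python; the pairing with the removed C functions below is our annotation based only on the header above, not on the release notes.

>>> import numpy as np
>>> import mindspore as ms
>>> # Create a 2x3 float32 tensor from a host buffer (cf. MSNewTensor).
>>> t = ms.Tensor(np.arange(6, dtype=np.float32).reshape(2, 3))
>>> t.dtype     # cf. MSTensorGetDataType
Float32
>>> t.shape     # cf. MSTensorGetShape
(2, 3)
>>> t.ndim      # cf. MSTensorGetDimension
2
>>> t.size      # cf. MSTensorGetElementNum
6
>>> t.nbytes    # cf. MSTensorGetDataSize
24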
mindspore/include/c_api/ms/value.h
@@ -1,84 +0,0 @@
- /**
- * Copyright 2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef MINDSPORE_CCSRC_C_API_INCLUDE_VALUE_H_
- #define MINDSPORE_CCSRC_C_API_INCLUDE_VALUE_H_
-
- #include <stdbool.h>
- #include <stdlib.h>
- #include "include/c_api/ms/base/macros.h"
- #include "include/c_api/ms/base/handle_types.h"
- #include "include/c_api/ms/base/types.h"
- #include "include/c_api/ms/context.h"
-
- #ifdef __cplusplus
- extern "C" {
- #endif
-
- /// \brief Create a new int64 scalar value.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] v Given value.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueInt64(ResMgrHandle res_mgr, const int64_t v);
-
- /// \brief Create a new float32 scalar value.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] v Given value.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueFloat32(ResMgrHandle res_mgr, const float v);
-
- /// \brief Create a new bool scalar value.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] v Given value.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueBool(ResMgrHandle res_mgr, const bool v);
-
- /// \brief Create a new value of DataType.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] type Given data type.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueType(ResMgrHandle res_mgr, DataTypeC type);
-
- /// \brief Create a new vector-of-strings value.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] strs Given value.
- /// \param[in] vec_len Length of the string vector.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueStrings(ResMgrHandle res_mgr, const char *strs[], size_t vec_len);
-
- /// \brief Create a new value with an array.
- ///
- /// \param[in] res_mgr Resource handle that manages the nodes of the funcGraph.
- /// \param[in] value Given array.
- /// \param[in] vec_size Given array size.
- /// \param[in] data_type Data type of the array.
- ///
- /// \return Value handle.
- MIND_C_API ValueHandle MSNewValueArray(ResMgrHandle res_mgr, void *value, size_t vec_size, DataTypeC data_type);
- #ifdef __cplusplus
- }
- #endif
- #endif  // MINDSPORE_CCSRC_C_API_INCLUDE_VALUE_H_
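
Editorial note: these removed value constructors wrapped host scalars and arrays as graph values. On the Python side, plain scalars and `mindspore.Tensor` objects play that role; a minimal sketch (the pairing with the removed C functions is our reading of the header, not an official mapping):

>>> import mindspore as ms
>>> i = ms.Tensor(42, ms.int64)            # cf. MSNewValueInt64
>>> f = ms.Tensor(1.5, ms.float32)         # cf. MSNewValueFloat32
>>> b = ms.Tensor(True, ms.bool_)          # cf. MSNewValueBool
>>> arr = ms.Tensor([1.0, 2.0, 3.0], ms.float32)  # cf. MSNewValueArray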
Binary file (mindspore/mindspore_shared_lib.dll): contents not shown.
mindspore/nn/extend/basic.py
@@ -1,140 +0,0 @@
- # Copyright 2024 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """basic"""
- from __future__ import absolute_import
-
- import math
-
- import mindspore.common.dtype as mstype
- from mindspore import _checkparam as Validator
- from mindspore._extends import cell_attr_register
- from mindspore.common.initializer import initializer, HeUniform, Uniform
- from mindspore.common.parameter import Parameter
- from mindspore.common.tensor import Tensor
- from mindspore.nn.cell import Cell
- from mindspore.ops import operations as P
-
- __all__ = ['Linear']
-
-
- class Linear(Cell):
-     r"""
-     The linear connected layer.
-
-     Applies a linear connected layer to the input. This layer implements the operation as:
-
-     .. math::
-         \text{outputs} = X * kernel + bias
-
-     where :math:`X` is the input tensor, :math:`\text{kernel}` is a weight matrix with the same
-     data type as :math:`X` created by the layer, and :math:`\text{bias}` is a bias vector
-     with the same data type as :math:`X` created by the layer (only if has_bias is True).
-
-     Args:
-         in_features (int): The number of features in the input space.
-         out_features (int): The number of features in the output space.
-         bias (bool): Specifies whether the layer uses a bias vector :math:`\text{bias}`. Default: ``True``.
-         weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
-             is the same as `x`. The values of str refer to the function `initializer`. Default: ``None``,
-             weight will be initialized using HeUniform.
-         bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
-             the same as `x`. The values of str refer to the function `initializer`. Default: ``None``,
-             bias will be initialized using Uniform.
-         dtype (:class:`mindspore.dtype`): Data type of Parameter. Default: ``None``.
-
-     Inputs:
-         - **x** (Tensor) - Tensor of shape :math:`(*, in\_features)`. The `in_features` in `Args` should be equal
-           to :math:`in\_features` in `Inputs`.
-
-     Outputs:
-         Tensor of shape :math:`(*, out\_features)`.
-
-     Raises:
-         TypeError: If `in_features` or `out_features` is not an int.
-         TypeError: If `bias` is not a bool.
-         ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
-             is not equal to `out_features` or shape[1] of `weight_init` is not equal to `in_features`.
-         ValueError: If length of shape of `bias_init` is not equal to 1
-             or shape[0] of `bias_init` is not equal to `out_features`.
-
-     Supported Platforms:
-         ``Ascend`` ``GPU`` ``CPU``
-
-     Examples:
-         >>> import mindspore
-         >>> from mindspore import Tensor
-         >>> from mindspore import nn
-         >>> import numpy as np
-         >>> x = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
-         >>> net = nn.extend.Linear(3, 4)
-         >>> output = net(x)
-         >>> print(output.shape)
-         (2, 4)
-     """
-
-     @cell_attr_register(attrs=['has_bias'])
-     def __init__(self,
-                  in_features,
-                  out_features,
-                  bias=True,
-                  weight_init=None,
-                  bias_init=None,
-                  dtype=None):
-         """Initialize Linear."""
-         super(Linear, self).__init__()
-         self.in_features = Validator.check_positive_int(
-             in_features, "in_features", self.cls_name)
-         self.out_features = Validator.check_positive_int(
-             out_features, "out_features", self.cls_name)
-         self.has_bias = Validator.check_bool(
-             bias, "has_bias", self.cls_name)
-         self.dense = P.Dense()
-         if dtype is None:
-             dtype = mstype.float32
-         if isinstance(weight_init, Tensor):
-             if weight_init.ndim != 2 or weight_init.shape[0] != out_features or \
-                     weight_init.shape[1] != in_features:
-                 raise ValueError(f"For '{self.cls_name}', weight init shape error. The ndim of 'weight_init' must "
-                                  f"be equal to 2, and the first dim must be equal to 'out_features', and the "
-                                  f"second dim must be equal to 'in_features'. But got 'weight_init': {weight_init}, "
-                                  f"'out_features': {out_features}, 'in_features': {in_features}.")
-         if weight_init is None:
-             weight_init = HeUniform(math.sqrt(5))
-         self.weight = Parameter(initializer(
-             weight_init, [out_features, in_features], dtype=dtype), name="weight")
-
-         self.bias = None
-         if self.has_bias:
-             if isinstance(bias_init, Tensor):
-                 if bias_init.ndim != 1 or bias_init.shape[0] != out_features:
-                     raise ValueError(f"For '{self.cls_name}', bias init shape error. The ndim of 'bias_init' must "
-                                      f"be equal to 1, and the first dim must be equal to 'out_features'. But got "
-                                      f"'bias_init': {bias_init}, 'out_features': {out_features}.")
-             if bias_init is None:
-                 bound = 1 / math.sqrt(in_features)
-                 bias_init = Uniform(scale=bound)
-             self.bias = Parameter(initializer(
-                 bias_init, [out_features], dtype=dtype), name="bias")
-
-     def construct(self, x):
-         x = self.dense(x, self.weight, self.bias)
-         return x
-
-     def extend_repr(self):
-         s = f'input_features={self.in_features}, output_features={self.out_features}'
-         if self.has_bias:
-             s += f', has_bias={self.has_bias}'
-         return s
1
- # Copyright 2024 Huawei Technologies Co., Ltd
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ============================================================================
15
- """embedding"""
16
- from __future__ import absolute_import
17
-
18
- import mindspore.common.dtype as mstype
19
- from mindspore.common.initializer import Normal
20
- from mindspore import _checkparam as Validator
21
- from mindspore.nn.cell import Cell
22
- from mindspore import ops
23
- from mindspore.common.parameter import Parameter
24
- from mindspore.common.tensor import Tensor
25
-
26
- __all__ = ['Embedding']
27
-
28
-
29
- class Embedding(Cell):
30
- r"""
31
- Embedding layer.
32
- Retrieve the word embeddings in weight stored in the layer using indices specified in `input`.
33
-
34
- .. warning::
35
- On Ascend, the behavior is unpredictable when the value of `input` is invalid.
36
-
37
- Args:
38
- num_embeddings (int): Size of the dictionary of embeddings.
39
- embedding_dim (int): The size of each embedding vector.
40
- padding_idx (int, optional): If the value is not None, the corresponding row of embedding vector
41
- will not be updated in training. The value of embedding vector at `padding_idx` will default
42
- to zeros when the Embedding layer is newly constructed. The value should be in range
43
- `[-num_embeddings, num_embeddings)` if it's not ``None``. Default ``None``.
44
- max_norm (float, optional): If the value is not None, firstly get the p-norm result of the embedding
45
- vector specified by `input` where p is specified by `norm_type`; if the result is larger then `max_norm`,
46
- update the embedding vector` with :math:`\frac{max\_norm}{result+1e^{-7}}`. Default ``None``.
47
- norm_type (float, optional): Indicated the value of p in p-norm. Default ``2.0``.
48
- scale_grad_by_freq (bool, optional): If ``True`` the gradients will be scaled by the inverse of frequency
49
- of the index in `input`. Default ``False``.
50
- _weight (Tensor, optional): Used to initialize the weight of Embedding. If ``None``, the weight will be
51
- initialized from normal distribution :math:`{N}(\text{sigma=1.0}, \text{mean=0.0})`. Default ``None``.
52
- dtype (mindspore.dtype, optional) : Dtype of Parameters. It is meaningless when `_weight` is not None.
53
- Default: ``mindspore.float32``.
54
-
55
- Inputs:
56
- - **input** (Tensor) - The indices used to lookup in the embedding vector. The data type must be
57
- mindspore.int32 or mindspore.int64, and the value should be in range `[0, num_embeddings)`.
58
-
59
- Outputs:
60
- Tensor, has the same data type as weight, the shape is :math:`(*input.shape, embedding\_dim)`.
61
-
62
- Raises:
63
- TypeError: If `num_embeddings` is not an int.
64
- TypeError: If `embedding_dim` is not an int.
65
- ValueError: If `padding_idx` is out of valid range.
66
- TypeError: If `max_norm` is not a float.
67
- TypeError: If `norm_type` is not a float.
68
- TypeError: If `scale_grad_by_freq` is not a bool.
69
- TypeError: If `dtype` is not one of mindspore.dtype.
70
-
71
- Supported Platforms:
72
- ``Ascend``
73
-
74
- Examples:
75
- >>> import mindspore
76
- >>> import numpy as np
77
- >>> from mindspore import Tensor, nn
78
- >>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
79
- >>> embedding = nn.extend.Embedding(num_embeddings=10, embedding_dim=3)
80
- >>> output = embedding(input)
81
- >>> print(output)
82
- [[[-0.0024154 -0.01203444 0.00811537]
83
- [ 0.00233847 -0.00596091 0.00536799]
84
- [-0.0024154 -0.01203444 0.00811537]
85
- [-0.0024154 -0.01203444 0.00811537]]
86
- [[ 0.00233847 -0.00596091 0.00536799]
87
- [ 0.00233847 -0.00596091 0.00536799]
88
- [-0.0024154 -0.01203444 0.00811537]
89
- [ 0.00233847 -0.00596091 0.00536799]]]
90
- """
91
-
92
- def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0,
93
- scale_grad_by_freq=False, _weight=None, dtype=mstype.float32):
94
- """Initialize Embedding."""
95
- super().__init__()
96
- self.num_embeddings = Validator.check_value_type(
97
- 'num_embeddings', num_embeddings, [int], self.cls_name)
98
- self.embedding_dim = Validator.check_value_type(
99
- 'embedding_dim', embedding_dim, [int], self.cls_name)
100
- Validator.check_subclass(
101
- "dtype", dtype, mstype.number_type, self.cls_name)
102
- self.dtype = dtype
103
- self.padding_idx = padding_idx
104
- if _weight is None:
105
- init_tensor = Tensor(shape=[num_embeddings, embedding_dim], dtype=dtype, init=Normal(1, 0))
106
- init_tensor = self._zero_weight_by_index(init_tensor)
107
- self.weight = Parameter(init_tensor, name='weight')
108
- else:
109
- self.weight = Parameter(_weight)
110
-
111
- self.max_norm = max_norm
112
- if max_norm is not None:
113
- self.max_norm = Validator.check_value_type('max_norm', max_norm, [float], self.cls_name)
114
-
115
- self.norm_type = norm_type
116
- if norm_type is not None:
117
- self.norm_type = Validator.check_value_type('norm_type', norm_type,
118
- [float], self.cls_name)
119
-
120
- self.scale_grad_by_freq = scale_grad_by_freq
121
- if scale_grad_by_freq is not None:
122
- self.scale_grad_by_freq = Validator.check_value_type('scale_grad_by_freq',
123
- scale_grad_by_freq,
124
- [bool], self.cls_name)
125
-
126
- def _zero_weight_by_index(self, init_tensor):
127
- if self.padding_idx is not None:
128
- self.padding_idx = Validator.check_int_range(self.padding_idx, -self.num_embeddings, self.num_embeddings,
129
- Validator.INC_LEFT, "padding_idx", self.cls_name)
130
- if isinstance(init_tensor, Tensor) and init_tensor.init is not None:
131
- init_tensor = init_tensor.init_data()
132
- init_tensor[self.padding_idx] = 0
133
-
134
- return init_tensor
135
-
136
- def construct(self, input):
137
- return ops.embedding(input, self.weight, self.padding_idx, self.max_norm,
138
- self.norm_type, self.scale_grad_by_freq)
139
-
140
- def extend_repr(self):
141
- return f'num_embeddings={self.num_embeddings}, embedding_dim={self.embedding_dim}, ' \
142
- f'padding_idx={self.padding_idx}, max_norm={self.max_norm}, norm_type={self.norm_type}, ' \
143
- f'scale_grad_by_freq={self.scale_grad_by_freq}, dtype={self.dtype}'
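
Editorial note: the removed `nn.extend.Embedding` overlaps with the long-standing `nn.Embedding`, though under different parameter names, and `max_norm`/`norm_type`/`scale_grad_by_freq` have no direct counterpart there. A hedged sketch mirroring the docstring example above:

>>> import mindspore as ms
>>> from mindspore import Tensor, nn
>>> input = Tensor([[1, 0, 1, 1], [0, 0, 1, 0]])
>>> # was: nn.extend.Embedding(num_embeddings=10, embedding_dim=3)
>>> embedding = nn.Embedding(vocab_size=10, embedding_size=3, padding_idx=None)
>>> print(embedding(input).shape)
(2, 4, 3)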
mindspore/nn/extend/layer/normalization.py
@@ -1,109 +0,0 @@
- # Copyright 2020-2024 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """normalization"""
- from __future__ import absolute_import
- from __future__ import division
-
- from mindspore.ops import functional as F
- from mindspore.common.parameter import Parameter
- from mindspore.common.initializer import initializer
- from mindspore.common import dtype as mstype
- from mindspore.nn.cell import Cell
-
- __all__ = ['LayerNorm']
-
-
- class LayerNorm(Cell):
-     r"""
-     Applies Layer Normalization over a mini-batch of inputs.
-
-     Layer Normalization is widely used in recurrent neural networks. It applies
-     normalization on a mini-batch of inputs for each single training case as described
-     in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
-     Normalization, Layer Normalization performs exactly the same computation at training and
-     testing time. It is applied across all channels and pixels of a single sample, rather than
-     across the batch. :math:`\gamma` and :math:`\beta` are trainable scale and shift.
-     It can be described using the following formula:
-
-     .. math::
-         y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
-
-     Args:
-         normalized_shape (Union(tuple[int], list[int])): The normalized shape of `x` for LayerNorm.
-         gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
-             The values of str refer to the function `initializer` including ``'zeros'``, ``'ones'``,
-             ``'xavier_uniform'``, ``'he_uniform'``, etc. Default: ``'ones'``.
-         beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
-             The values of str refer to the function `initializer` including ``'zeros'``, ``'ones'``,
-             ``'xavier_uniform'``, ``'he_uniform'``, etc. Default: ``'zeros'``.
-         eps (float): A value added to the denominator for numerical stability (:math:`\epsilon`). Default: ``1e-5``.
-         elementwise_affine (bool): A bool value. When set to ``True``, gamma and beta can be learned.
-             Default: ``True``.
-         dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32``.
-
-     Inputs:
-         - **x** (Tensor) - The shape is :math:`(N, *)`, where :math:`*` means any number of additional dimensions.
-
-     Outputs:
-         Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.
-
-     Raises:
-         TypeError: If `eps` is not a float.
-
-     Supported Platforms:
-         ``Ascend``
-
-     Examples:
-         >>> import mindspore as ms
-         >>> import numpy as np
-         >>> x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
-         >>> shape1 = x.shape[1:]
-         >>> m = ms.nn.extend.LayerNorm(shape1)
-         >>> output = m(x).shape
-         >>> print(output)
-         (20, 5, 10, 10)
-     """
-
-     def __init__(self,
-                  normalized_shape,
-                  gamma_init='ones',
-                  beta_init='zeros',
-                  eps=1e-5,
-                  elementwise_affine=True,
-                  dtype=mstype.float32
-                  ):
-         """Initialize LayerNorm."""
-         super(LayerNorm, self).__init__()
-         if not isinstance(normalized_shape, (tuple, list)):
-             raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
-                             f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
-         if not normalized_shape:
-             raise ValueError(
-                 f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
-                 f"least one element, but got normalized_shape = {normalized_shape}"
-             )
-         self.normalized_shape = normalized_shape
-         self.epsilon = eps
-         self.gamma = Parameter(initializer(
-             gamma_init, normalized_shape, dtype=dtype), name="gamma", requires_grad=elementwise_affine)
-         self.beta = Parameter(initializer(
-             beta_init, normalized_shape, dtype=dtype), name="beta", requires_grad=elementwise_affine)
-
-     def construct(self, input_x):
-         y = F.layer_norm(input_x, self.normalized_shape, self.gamma.astype(input_x.dtype),
-                          self.beta.astype(input_x.dtype), self.epsilon)
-         return y
-
-     def extend_repr(self):
-         return 'normalized_shape={}, gamma={}, beta={}'.format(self.normalized_shape, self.gamma, self.beta)
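
Editorial note: the removed `nn.extend.LayerNorm` normalized over the trailing `normalized_shape` dimensions; the long-standing `nn.LayerNorm` does the same once `begin_norm_axis`/`begin_params_axis` point at the first normalized dimension. A minimal sketch mirroring the docstring example above:

>>> import mindspore as ms
>>> import numpy as np
>>> x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
>>> # was: ms.nn.extend.LayerNorm(x.shape[1:])
>>> m = ms.nn.LayerNorm(x.shape[1:], begin_norm_axis=1, begin_params_axis=1)
>>> print(m(x).shape)
(20, 5, 10, 10)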