mindspore-2.4.0-cp311-cp311-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1273 @@
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """normalization"""
16
+ from __future__ import absolute_import
17
+ from __future__ import division
18
+
19
+ import itertools
20
+ import numbers
21
+ import hashlib
22
+ import numpy as np
23
+ import mindspore.ops as ops
24
+ from mindspore.ops import operations as P
25
+ from mindspore.ops.operations import _inner_ops as inner
26
+ from mindspore.common.parameter import Parameter
27
+ from mindspore.common.initializer import initializer, Initializer
28
+ from mindspore.common.tensor import Tensor
29
+ from mindspore.ops.primitive import constexpr, _primexpr
30
+ import mindspore.context as context
31
+ from mindspore import _checkparam as validator
32
+ from mindspore._extends import cell_attr_register
33
+ from mindspore.communication.management import get_group_size, get_rank
34
+ from mindspore.communication import management
35
+ from mindspore.common import dtype as mstype
36
+ from mindspore.parallel._utils import _is_in_auto_parallel_mode
37
+ from mindspore.nn.cell import Cell
38
+ from mindspore import log as logger
39
+ from mindspore.ops import group_norm
40
+
41
+ __all__ = ['BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'LayerNorm', 'LayerNormExt', 'GroupNorm',
42
+ 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm2d', 'InstanceNorm3d']
43
+
44
+
45
+ def _check_dim(val, target, cls_name):
46
+ def _check(val, target, cls_name):
47
+ if val != target:
48
+ raise ValueError(f"For '{cls_name}', the in_shape must have {target} dims, but got {val}.")
49
+ _check(val, target, cls_name)
50
+
51
+
52
+ class _BatchNorm(Cell):
53
+ """Batch Normalization base class."""
54
+
55
+ @cell_attr_register
56
+ def __init__(self,
57
+ num_features,
58
+ eps=1e-5,
59
+ momentum=0.9,
60
+ affine=True,
61
+ gamma_init='ones',
62
+ beta_init='zeros',
63
+ moving_mean_init='zeros',
64
+ moving_var_init='ones',
65
+ use_batch_statistics=None,
66
+ data_format='NCHW',
67
+ dtype=mstype.float32):
68
+ """Initialize _BatchNorm."""
69
+ super(_BatchNorm, self).__init__()
70
+ validator.check_value_type('num_features', num_features, [int], self.cls_name)
71
+ if num_features < 1:
72
+ raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")
73
+
74
+ if momentum < 0 or momentum > 1:
75
+ raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
76
+ f"but got {momentum}.")
77
+ self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
78
+ if context.get_context("device_target") != "GPU" and self.format == "NHWC":
79
+ raise ValueError(f"For '{self.cls_name}', the 'NHWC' format only support in GPU target, but got device "
80
+ f"target {context.get_context('device_target')}.")
81
+ self.use_batch_statistics = use_batch_statistics
82
+ if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
83
+ raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None,"
84
+ f" but got {use_batch_statistics}.")
85
+ self.num_features = num_features
86
+ self.eps = eps
87
+ self.beta_init = beta_init
88
+ self.gamma_init = gamma_init
89
+ self.moving_mean_init = moving_mean_init
90
+ self.moving_var_init = moving_var_init
91
+ self.moving_mean = Parameter(initializer(
92
+ moving_mean_init, num_features, dtype=dtype), name="mean", requires_grad=False)
93
+ self.moving_variance = Parameter(initializer(
94
+ moving_var_init, num_features, dtype=dtype), name="variance", requires_grad=False)
95
+ self.gamma = Parameter(initializer(
96
+ gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
97
+ self.beta = Parameter(initializer(
98
+ beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
99
+
100
+ self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
101
+
102
+ self.shape = P.Shape()
103
+ self.reduce_mean = P.ReduceMean(keep_dims=True)
104
+ self.square = P.Square()
105
+ self.sqrt = P.Sqrt()
106
+ self.cast = P.Cast()
107
+ self.dtype = P.DType()
108
+ self.reshape = P.Reshape()
109
+ self._target = context.get_context("device_target")
110
+ self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
111
+ self.momentum = 1.0 - momentum
112
+
113
+ self.bn_train = P.BatchNorm(is_training=True,
114
+ epsilon=self.eps,
115
+ momentum=self.momentum,
116
+ data_format=self.format)
117
+
118
+ self.bn_infer = P.BatchNorm(is_training=False, epsilon=self.eps, data_format=self.format)
119
+ if _is_in_auto_parallel_mode():
120
+ data_parallel_strategy = ((1,), (1,))
121
+ data_parallel_strategy_one = ((1,), ())
122
+ else:
123
+ data_parallel_strategy = None
124
+ data_parallel_strategy_one = None
125
+ self.sub_mean = P.Sub().shard(data_parallel_strategy)
126
+ self.sub_var = P.Sub().shard(data_parallel_strategy)
127
+ self.mul_mean = P.Mul().shard(data_parallel_strategy_one)
128
+ self.mul_var = P.Mul().shard(data_parallel_strategy_one)
129
+ self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
130
+ self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
131
+
132
+ @staticmethod
133
+ @_primexpr
134
+ def _check_input_dim(shape, cls_name):
135
+ raise NotImplementedError
136
+
137
+ def construct(self, x):
138
+ self._check_input_dim(self.shape(x), self.cls_name)
139
+ x_shape = self.shape(x)
140
+ reshaped_x = x
141
+ if len(x_shape) == 2:
142
+ reshaped_x = self.reshape(x, (x_shape[0], x_shape[1], 1, 1))
143
+ elif len(x_shape) == 3:
144
+ reshaped_x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2], 1))
145
+ if self.use_batch_statistics is None:
146
+ if self.training:
147
+ return self.bn_train(x,
148
+ self.gamma,
149
+ self.beta,
150
+ self.moving_mean,
151
+ self.moving_variance)[0]
152
+ if not self.training:
153
+ bn_out = self.bn_infer(reshaped_x,
154
+ self.gamma,
155
+ self.beta,
156
+ self.moving_mean,
157
+ self.moving_variance)[0]
158
+ if len(x_shape) < 4:
159
+ bn_out = self.reshape(bn_out, x_shape)
160
+ return bn_out
161
+
162
+ if self.use_batch_statistics:
163
+ return self.bn_train(x,
164
+ self.gamma,
165
+ self.beta,
166
+ self.moving_mean,
167
+ self.moving_variance)[0]
168
+
169
+ bn_out = self.bn_infer(reshaped_x,
170
+ self.gamma,
171
+ self.beta,
172
+ self.moving_mean,
173
+ self.moving_variance)[0]
174
+ if len(x_shape) < 4:
175
+ bn_out = self.reshape(bn_out, x_shape)
176
+ return bn_out
177
+
178
+ def extend_repr(self):
179
+ return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
180
+ self.num_features, self.eps, 1.0 - self.momentum, self.gamma, self.beta, \
181
+ self.moving_mean, self.moving_variance)
182
+
183
+
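Editorial note on the momentum handling above: the layer argument `momentum` is the weight kept on the running statistic, and `__init__` forwards `1.0 - momentum` to the `P.BatchNorm` primitive, which appears to weight the current-batch statistics by that amount (consistent with the moving-statistics formula in the BatchNorm2d docstring below). A minimal sketch of the resulting update, using plain Python numbers (made up for illustration) rather than the real operator:

momentum = 0.9                       # nn-layer argument: weight of the running statistic
op_momentum = 1.0 - momentum         # value forwarded to P.BatchNorm in __init__ above

moving_mean, moving_var = 0.0, 1.0   # hypothetical running statistics
batch_mean, batch_var = 0.5, 0.25    # hypothetical statistics of the current mini-batch

moving_mean = moving_mean * momentum + batch_mean * op_momentum   # 0.05
moving_var = moving_var * momentum + batch_var * op_momentum      # 0.925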
184
+ class BatchNorm1d(_BatchNorm):
185
+ r"""
186
+ This layer
187
+ applies Batch Normalization over a 2D or 3D input (a mini-batch of 1D or 2D inputs) to
188
+ reduce internal covariate shift. Batch Normalization is widely used in convolutional networks.
189
+ For the detailed contents, refer to `Batch Normalization: Accelerating Deep Network Training by
190
+ Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
191
+ rescales and recenters the feature using a mini-batch of data and
192
+ the learned parameters which can be described in the following formula.
193
+
194
+ .. math::
195
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
196
+
197
+ Note:
198
+ The implementation of BatchNorm is different in graph mode and pynative mode, therefore it is not
200
+ recommended to change the mode after the net is initialized.
200
+
201
+ Args:
202
+ num_features (int): number of features or channels `C` of the input `x` .
203
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: ``1e-5`` .
204
+ momentum (float): A floating hyperparameter of the momentum for the
205
+ running_mean and running_var computation. Default: ``0.9`` .
206
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
207
+ Default: ``True`` .
208
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
209
+ The values of str refer to the function `mindspore.common.initializer
210
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
211
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
212
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
213
+ The values of str refer to the function `mindspore.common.initializer
214
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
215
+ including ``'zeros'`` , ``'ones'``, etc. Default: ``'zeros'`` .
216
+ moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
217
+ The values of str refer to the function `mindspore.common.initializer
218
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
219
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
220
+ moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
221
+ The values of str refer to the function `mindspore.common.initializer
222
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
223
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
224
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
225
+ ``false`` , use the specified mean and variance values. If ``None`` , the training process
226
+ will use the mean and variance of current batch data and track the running mean and variance, the
227
+ evaluation process will use the running mean and variance. Default: ``None`` .
228
+ data_format (str): The optional value for data format, which can be ``'NHWC'`` or ``'NCHW'`` .
229
+ Default: ``'NCHW'`` .
230
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
231
+
232
+ Inputs:
233
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)` or :math:`(N, C, L)` ,
234
+ where `N` is the batch size, `C` is the number of features or channels, and `L` is the sequence length.
235
+ Supported types: float16, float32.
236
+
237
+ Outputs:
238
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C)` or :math:`(N, C, L)` .
239
+
240
+ Raises:
241
+ TypeError: If `num_features` is not an int.
242
+ TypeError: If `eps` is not a float.
243
+ ValueError: If `num_features` is less than 1.
244
+ ValueError: If `momentum` is not in range [0, 1].
245
+
246
+ Supported Platforms:
247
+ ``Ascend`` ``GPU`` ``CPU``
248
+
249
+ Examples:
250
+ >>> import numpy as np
251
+ >>> import mindspore as ms
252
+ >>> net = ms.nn.BatchNorm1d(num_features=4)
253
+ >>> x = ms.Tensor(np.array([[0.7, 0.5, 0.5, 0.6],
254
+ ... [0.5, 0.4, 0.6, 0.9]]).astype(np.float32))
255
+ >>> output = net(x)
256
+ >>> print(output)
257
+ [[ 0.6999965 0.4999975 0.4999975 0.59999704 ]
258
+ [ 0.4999975 0.399998 0.59999704 0.89999545 ]]
259
+ """
260
+
261
+ @staticmethod
262
+ @_primexpr
263
+ def _check_input_dim(shape, cls_name):
264
+ def _check(dim):
265
+ if dim not in (2, 3):
266
+ raise ValueError(f"For '{cls_name}', the must have 2 dims or 3 dims, but got {dim}.")
267
+ dim = len(shape)
268
+ _check(dim)
269
+
270
+
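Illustrative usage sketch (not part of the diff): with `use_batch_statistics=None`, the branch taken in `_BatchNorm.construct` follows the cell's training flag, so the same layer switches between batch statistics and the tracked moving statistics via `set_train`. This assumes a working MindSpore 2.4 environment; the shapes and values below are arbitrary.

import numpy as np
import mindspore as ms

bn = ms.nn.BatchNorm1d(num_features=4)
x = ms.Tensor(np.random.randn(8, 4).astype(np.float32))

bn.set_train(True)    # training mode: normalize with batch statistics, update moving stats
y_train = bn(x)

bn.set_train(False)   # evaluation mode: normalize with moving_mean / moving_variance
y_eval = bn(x)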
271
+ class BatchNorm2d(_BatchNorm):
272
+ r"""
273
+ Batch Normalization is widely used in convolutional networks. This layer
274
+ applies Batch Normalization over a 4D input (a mini-batch of 2D inputs with
275
+ additional channel dimension) to avoid internal covariate shift as described
276
+ in the paper `Batch Normalization: Accelerating Deep Network Training by
277
+ Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It
278
+ rescales and recenters the feature using a mini-batch of data and
279
+ the learned parameters which can be described in the following formula.
280
+
281
+ .. math::
282
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
283
+
284
+ Note:
285
+ The implementation of BatchNorm is different in graph mode and pynative mode, therefore the mode cannot be
286
+ changed after the net is initialized.
287
+ Note that the formula for updating the :math:`moving\_mean` and :math:`moving\_var` is
288
+
289
+ .. math::
290
+ \text{moving\_mean} = \text{moving\_mean} \times \text{momentum} + \mu_\beta \times (1 - \text{momentum}) \\
291
+ \text{moving\_var} = \text{moving\_var} \times \text{momentum} + \sigma^2_\beta \times (1 - \text{momentum})
292
+
293
+ where :math:`moving\_mean` is the updated mean, :math:`moving\_var` is the updated variance,
294
+ :math:`\mu_\beta, \sigma^2_\beta` are the observed values (mean and variance) of each batch of data.
295
+
296
+ Args:
297
+ num_features (int): The number of channels of the input tensor. Expected input size is :math:`(N, C, H, W)`,
298
+ `C` represents the number of channels.
299
+ eps (float): :math:`\epsilon` added to the denominator for numerical stability. Default: ``1e-5`` .
300
+ momentum (float): A floating hyperparameter of the momentum for the
301
+ running_mean and running_var computation. Default: ``0.9`` .
302
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
303
+ Default: ``True`` .
304
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
305
+ The values of str refer to the function `mindspore.common.initializer
306
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
307
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
308
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
309
+ The values of str refer to the function `mindspore.common.initializer
310
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
311
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
312
+ moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
313
+ The values of str refer to the function `mindspore.common.initializer
314
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
315
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
316
+ moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
317
+ The values of str refer to the function `mindspore.common.initializer
318
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
319
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
320
+ use_batch_statistics (bool): Default: ``None`` .
321
+
322
+ - If ``true`` , use the mean value and variance value of current batch data and track running mean
323
+ and running variance.
324
+ - If ``false`` , use the specified mean and variance values, and do not track the statistics.
325
+ - If ``None`` , the use_batch_statistics is automatically set to ``true`` or ``false`` according to the
326
+ training and evaluation mode. During training, the parameter is set to true, and during evaluation, the
327
+ parameter is set to false.
328
+
329
+ data_format (str): The optional value for data format, which can be ``'NHWC'`` or ``'NCHW'`` .
330
+ Default: ``'NCHW'`` .
331
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
332
+
333
+ Inputs:
334
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Supported types: float16, float32.
335
+
336
+ Outputs:
337
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`.
338
+
339
+ Raises:
340
+ TypeError: If `num_features` is not an int.
341
+ TypeError: If `eps` is not a float.
342
+ ValueError: If `num_features` is less than 1.
343
+ ValueError: If `momentum` is not in range [0, 1].
344
+ ValueError: If `data_format` is neither 'NHWC' nor 'NCHW'.
345
+
346
+ Supported Platforms:
347
+ ``Ascend`` ``GPU`` ``CPU``
348
+
349
+ Examples:
350
+ >>> import numpy as np
351
+ >>> import mindspore as ms
352
+ >>> net = ms.nn.BatchNorm2d(num_features=3)
353
+ >>> x = ms.Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
354
+ >>> output = net(x)
355
+ >>> print(output)
356
+ [[[[ 0.999995 0.999995 ]
357
+ [ 0.999995 0.999995 ]]
358
+ [[ 0.999995 0.999995 ]
359
+ [ 0.999995 0.999995 ]]
360
+ [[ 0.999995 0.999995 ]
361
+ [ 0.999995 0.999995 ]]]]
362
+ """
363
+
364
+ @staticmethod
365
+ @_primexpr
366
+ def _check_input_dim(shape, cls_name):
367
+ dim = len(shape)
368
+ _check_dim(dim, 4, cls_name)
369
+
370
+
371
+ class BatchNorm3d(Cell):
372
+ r"""
373
+ Batch Normalization is widely used in convolutional networks. This layer
374
+ applies Batch Normalization over a 5D input (a mini-batch of 3D inputs with
375
+ additional channel dimension) to avoid internal covariate shift.
376
+
377
+ .. math::
378
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
379
+
380
+ Note:
381
+ The implementation of BatchNorm is different in graph mode and pynative mode, therefore the mode cannot be
382
+ changed after the net is initialized.
383
+ Note that the formula for updating the running_mean and running_var is
384
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
385
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
386
+
387
+ Args:
388
+ num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)` .
389
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
390
+ momentum (float): A floating hyperparameter of the momentum for the
391
+ running_mean and running_var computation. Default: ``0.9`` .
392
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
393
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
394
+ The values of str refer to the function `mindspore.common.initializer
395
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
396
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
397
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
398
+ The values of str refer to the function `mindspore.common.initializer
399
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
400
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
401
+ moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
402
+ The values of str refer to the function `mindspore.common.initializer
403
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
404
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'zeros'`` .
405
+ moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
406
+ The values of str refer to the function `mindspore.common.initializer
407
+ <https://www.mindspore.cn/docs/en/master/api_python/mindspore.common.initializer.html>`_
408
+ including ``'zeros'`` , ``'ones'`` , etc. Default: ``'ones'`` .
409
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
410
+ ``false`` , use the specified mean and variance values. If ``None`` , the training process
411
+ will use the mean and variance of current batch data and track the running mean and variance, the
412
+ evaluation process will use the running mean and variance. Default: ``None`` .
413
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
414
+
415
+ Inputs:
416
+ - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
417
+ Supported types: float16, float32.
418
+
419
+ Outputs:
420
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, D_{out},H_{out}, W_{out})`.
421
+
422
+ Raises:
423
+ TypeError: If `num_features` is not an int.
424
+ TypeError: If `eps` is not a float.
425
+ ValueError: If `num_features` is less than 1.
426
+ ValueError: If `momentum` is not in range [0, 1].
427
+
428
+ Supported Platforms:
429
+ ``Ascend`` ``GPU`` ``CPU``
430
+
431
+ Examples:
432
+ >>> import numpy as np
433
+ >>> import mindspore as ms
434
+ >>> net = ms.nn.BatchNorm3d(num_features=3)
435
+ >>> x = ms.Tensor(np.ones([16, 3, 10, 32, 32]).astype(np.float32))
436
+ >>> output = net(x)
437
+ >>> print(output.shape)
438
+ (16, 3, 10, 32, 32)
439
+ """
440
+
441
+ def __init__(self,
442
+ num_features,
443
+ eps=1e-5,
444
+ momentum=0.9,
445
+ affine=True,
446
+ gamma_init='ones',
447
+ beta_init='zeros',
448
+ moving_mean_init='zeros',
449
+ moving_var_init='ones',
450
+ use_batch_statistics=None,
451
+ dtype=mstype.float32):
452
+ """Initialize BatchNorm3d."""
453
+ super(BatchNorm3d, self).__init__()
454
+ self.bn2d = BatchNorm2d(num_features=num_features,
455
+ eps=eps,
456
+ momentum=momentum,
457
+ affine=affine,
458
+ gamma_init=gamma_init,
459
+ beta_init=beta_init,
460
+ moving_mean_init=moving_mean_init,
461
+ moving_var_init=moving_var_init,
462
+ use_batch_statistics=use_batch_statistics,
463
+ data_format="NCHW",
464
+ dtype=dtype)
465
+ self.shape = P.Shape()
466
+ self.reshape = P.Reshape()
467
+
468
+ @staticmethod
469
+ @_primexpr
470
+ def _check_input_dim(shape, cls_name):
471
+ dim = len(shape)
472
+ _check_dim(dim, 5, cls_name)
473
+
474
+ def construct(self, x):
475
+ x_shape = self.shape(x)
476
+ self._check_input_dim(x_shape, self.cls_name)
477
+ x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
478
+ bn2d_out = self.bn2d(x)
479
+ bn3d_out = self.reshape(bn2d_out, x_shape)
480
+ return bn3d_out
481
+
482
+
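Shape-only sketch (plain NumPy, outside the diff) of the fold performed by `BatchNorm3d.construct` above: the depth and height axes of an `(N, C, D, H, W)` input are merged so the 2D batch-norm kernel can be reused, then the original shape is restored.

import numpy as np

x = np.zeros((16, 3, 10, 32, 32), dtype=np.float32)   # (N, C, D, H, W)
n, c, d, h, w = x.shape
folded = x.reshape(n, c, d * h, w)                    # (16, 3, 320, 32): a valid NCHW layout
restored = folded.reshape(n, c, d, h, w)              # back to the original 5D shape
assert restored.shape == x.shape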
483
+ SYNCBN_GROUP_DICT = None
484
+
485
+
486
+ def _syncbatchnorm_group_dict():
487
+ global SYNCBN_GROUP_DICT
488
+ if SYNCBN_GROUP_DICT is None:
489
+ SYNCBN_GROUP_DICT = dict()
490
+ return SYNCBN_GROUP_DICT
491
+
492
+
493
+ class SyncBatchNorm(_BatchNorm):
494
+ r"""
495
+ Sync Batch Normalization layer over a N-dimension input.
496
+
497
+ Sync Batch Normalization is cross device synchronized Batch Normalization. The implementation of Batch
498
+ Normalization only normalizes the data within each device. Sync Batch Normalization will normalize the input
499
+ within the group. It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
500
+ Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
501
+ feature using a mini-batch of data and the learned parameters which can be described in the following formula.
502
+
503
+ .. math::
504
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
505
+
506
+ Note:
507
+ Currently, SyncBatchNorm only supports 2D and 4D inputs.
508
+ :math:`\gamma` and :math:`\beta` are trainable scale and shift.
509
+
510
+ Args:
511
+ num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
512
+ eps (float): :math:`\epsilon`, a value added to the denominator for numerical stability. Default: ``1e-5`` .
513
+ momentum (float): A floating hyperparameter of the momentum for the
514
+ running_mean and running_var computation. Default: ``0.9`` .
515
+ affine (bool): A bool value. When set to ``True`` , :math:`\gamma` and :math:`\beta` can be learned.
516
+ Default: ``True`` .
517
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
518
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
519
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
520
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
521
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
522
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
523
+ moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
524
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
525
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
526
+ moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
527
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
528
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
529
+ use_batch_statistics (bool): If ``true`` , use the mean value and variance value of current batch data. If
530
+ ``false`` , use the specified mean and variance values. If ``None`` , the training process will
531
+ use the mean and variance of current batch data and track the running mean and variance, eval process will
532
+ use the running mean and variance. Default: ``None`` .
533
+ process_groups (list): A list to divide devices into different sync groups, containing N sublists.
535
+ Each sublist contains int numbers identifying rank ids which need to be synchronized in the same
535
+ group. All int values must be in [0, rank_size) and different from each other. Default: ``None`` ,
536
+ indicating synchronization across all devices.
537
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
538
+
539
+ Inputs:
540
+ - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
541
+
542
+ Outputs:
543
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
544
+
545
+ Raises:
546
+ TypeError: If `num_features` is not an int.
547
+ TypeError: If `eps` is not a float.
548
+ TypeError: If `process_groups` is not a list.
549
+ ValueError: If `num_features` is less than 1.
550
+ ValueError: If `momentum` is not in range [0, 1].
551
+ ValueError: If rank_id in `process_groups` is not in range [0, rank_size).
552
+
553
+ Supported Platforms:
554
+ ``Ascend``
555
+
556
+ Examples:
557
+ .. note::
558
+ Before running the following examples, you need to configure the communication environment variables.
559
+
560
+ For the Ascend devices, users need to prepare the rank table, set rank_id and device_id.
561
+ Please see the `Ascend tutorial
562
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/rank_table.html>`_
563
+ for more details.
564
+
565
+ For the GPU devices, users need to prepare the host file and mpi, please see the `mpirun Startup
566
+ <https://www.mindspore.cn/docs/en/master/model_train/parallel/mpirun.html>`_ .
567
+
568
+ For the CPU device, users need to write a dynamic cluster startup script, please see the `Dynamic Cluster
569
+ Startup <https://www.mindspore.cn/docs/en/master/model_train/parallel/dynamic_cluster.html>`_ .
570
+
571
+ This example should be run with multiple devices.
572
+
573
+ >>> import numpy as np
574
+ >>> import mindspore as ms
575
+ >>> from mindspore.communication import init
576
+ >>>
577
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
578
+ >>> init()
579
+ >>> ms.reset_auto_parallel_context()
580
+ >>> ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
581
+ >>> sync_bn_op = ms.nn.SyncBatchNorm(num_features=3, process_groups=[[0, 1], [2, 3]])
582
+ >>> x = ms.Tensor(np.ones([1, 3, 2, 2]), ms.float32)
583
+ >>> output = sync_bn_op(x)
584
+ >>> print(output)
585
+ [[[[ 0.999995 0.999995 ]
586
+ [ 0.999995 0.999995 ]]
587
+ [[ 0.999995 0.999995 ]
588
+ [ 0.999995 0.999995 ]]
589
+ [[ 0.999995 0.999995 ]
590
+ [ 0.999995 0.999995 ]]]]
591
+ """
592
+ @cell_attr_register(attrs=['num_features', 'process_groups'])
593
+ def __init__(self,
594
+ num_features,
595
+ eps=1e-5,
596
+ momentum=0.9,
597
+ affine=True,
598
+ gamma_init='ones',
599
+ beta_init='zeros',
600
+ moving_mean_init='zeros',
601
+ moving_var_init='ones',
602
+ use_batch_statistics=None,
603
+ process_groups=None,
604
+ dtype=mstype.float32):
605
+ """Initialize SyncBatchNorm."""
606
+ super(SyncBatchNorm, self).__init__(num_features,
607
+ eps,
608
+ momentum,
609
+ affine,
610
+ gamma_init,
611
+ beta_init,
612
+ moving_mean_init,
613
+ moving_var_init,
614
+ use_batch_statistics,
615
+ dtype=dtype)
616
+ self.is_global = False
617
+ self.group_name = None
618
+ self.process_groups = process_groups
619
+ if self.process_groups != 0:
620
+ self.rank_id = get_rank()
621
+ self.rank_size = get_group_size()
622
+ if self.process_groups is not None:
623
+ validator.check_isinstance("process_groups", self.process_groups, list)
624
+ self._check_rank_ids(self.process_groups, self.rank_size)
625
+ self._create_sync_groups()
626
+ elif self.rank_size > 1:
627
+ self.is_global = True
628
+ self.group_device_num = self.rank_size
629
+ if context.get_context("device_target") == "Ascend":
630
+ self.group_name = "hccl_world_group"
631
+ elif context.get_context("device_target") == "GPU":
632
+ self.group_name = "nccl_world_group"
633
+
634
+ if self.is_global:
635
+ self.bn_train = inner.SyncBatchNorm(epsilon=self.eps,
636
+ momentum=self.momentum,
637
+ group=self.group_name,
638
+ device_num=self.group_device_num)
639
+
640
+ def _create_sync_groups(self):
641
+ """ create groups by process groups. """
642
+ for sub_group in self.process_groups:
643
+ validator.check_isinstance("sub group", sub_group, list)
644
+ self.group_device_num = len(sub_group)
645
+ if self.rank_id in sub_group and self.group_device_num > 1:
646
+ self.is_global = True
647
+ rank_list_name = '_'.join('%s' % id for id in sub_group)
648
+ group_dict = _syncbatchnorm_group_dict()
649
+ if rank_list_name not in group_dict:
650
+ md5 = hashlib.md5()
651
+ md5.update(rank_list_name.encode('utf-8'))
652
+ hash_name = md5.hexdigest()
653
+ self.group_name = str(self.group_device_num) + '_' + hash_name
654
+ group_dict[rank_list_name] = self.group_name
655
+ management.create_group(self.group_name, sub_group)
656
+ logger.info("create group for sync batchnorm, the rank list is {}, the group name is {}".format(
657
+ rank_list_name, self.group_name))
658
+ else:
659
+ self.group_name = group_dict[rank_list_name]
660
+ logger.info("the group for {} already exists, no need to create".format(rank_list_name))
661
+
662
+ def _check_rank_ids(self, process_groups, rank_size):
663
+ seen = set()
664
+ for rid in itertools.chain(*process_groups):
665
+ validator.check_int_range(rid, 0, rank_size, validator.INC_LEFT, "rank id in process_groups", self.cls_name)
666
+ if rid in seen:
667
+ raise ValueError(f"For '{self.cls_name}', rank id in 'process_groups' must not be duplicated, "
668
+ f"but got {process_groups}.")
669
+ seen.add(rid)
670
+
671
+ @staticmethod
672
+ @_primexpr
673
+ def _check_input_dim(shape, cls_name):
674
+ def _check(dim):
675
+ if dim not in (2, 4):
676
+ raise ValueError(f"For '{cls_name}', the must have 2 dims or 4 dims, but got {dim}.")
677
+ dim = len(shape)
678
+ _check(dim)
679
+
680
+
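Sketch of the deterministic group naming used by `_create_sync_groups` above, so that every rank in a sub-group derives the same communication group name; the rank list here is hypothetical. The resulting name is cached in SYNCBN_GROUP_DICT so a group is only created once per rank list.

import hashlib

sub_group = [0, 1]                                            # hypothetical sync group
rank_list_name = '_'.join('%s' % rid for rid in sub_group)    # "0_1"
hash_name = hashlib.md5(rank_list_name.encode('utf-8')).hexdigest()
group_name = str(len(sub_group)) + '_' + hash_name            # e.g. "2_<md5 hexdigest>"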
681
+ class LayerNorm(Cell):
682
+ r"""
683
+ Applies Layer Normalization over a mini-batch of inputs.
684
+
685
+ Layer Normalization is widely used in recurrent neural networks. It applies
686
+ normalization on a mini-batch of inputs for each single training case as described
687
+ in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_. Unlike Batch
688
+ Normalization, Layer Normalization performs exactly the same computation at training and
689
+ testing time. It is applied across all channels and pixels within a single sample, rather than across the batch.
690
+ :math:`\gamma` and :math:`\beta` are trainable scale and shift.
691
+ It can be described using the following formula:
692
+
693
+ .. math::
694
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
695
+
696
+ Args:
697
+ normalized_shape (Union(tuple[int], list[int])): The normalization is performed over axis
698
+ `begin_norm_axis ... R - 1`. R is the dimension size of input `x`.
699
+ begin_norm_axis (int): The first normalization dimension: normalization will be performed along dimensions
700
+ `begin_norm_axis: R`, the value should be in [-1, R). Default: ``-1`` .
701
+ begin_params_axis (int): The begin axis of the parameter input :math:`(\gamma, \beta)` to
702
+ apply LayerNorm, the value should be in [-1, R). Default: ``-1`` .
703
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\gamma` weight.
704
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
705
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` .
706
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the :math:`\beta` weight.
707
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
708
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` .
709
+ epsilon (float): A value added to the denominator for numerical stability (:math:`\epsilon`). Default: ``1e-7`` .
710
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
711
+
712
+ Inputs:
713
+ - **x** (Tensor) - The shape of `x` is :math:`(x_1, x_2, ..., x_R)`,
714
+ and `input_shape[begin_norm_axis:]` is equal to `normalized_shape`.
715
+
716
+ Outputs:
717
+ Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.
718
+
719
+ Raises:
720
+ TypeError: If `normalized_shape` is neither a list nor tuple.
721
+ TypeError: If `begin_norm_axis` or `begin_params_axis` is not an int.
722
+ TypeError: If `epsilon` is not a float.
723
+
724
+ Supported Platforms:
725
+ ``Ascend`` ``GPU`` ``CPU``
726
+
727
+ Examples:
728
+ >>> import mindspore as ms
729
+ >>> import numpy as np
730
+ >>> x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
731
+ >>> shape1 = x.shape[1:]
732
+ >>> m = ms.nn.LayerNorm(shape1, begin_norm_axis=1, begin_params_axis=1)
733
+ >>> output = m(x).shape
734
+ >>> print(output)
735
+ (20, 5, 10, 10)
736
+ """
737
+
738
+ def __init__(self,
739
+ normalized_shape,
740
+ begin_norm_axis=-1,
741
+ begin_params_axis=-1,
742
+ gamma_init='ones',
743
+ beta_init='zeros',
744
+ epsilon=1e-7,
745
+ dtype=mstype.float32
746
+ ):
747
+ """Initialize LayerNorm."""
748
+ super(LayerNorm, self).__init__()
749
+ if not isinstance(normalized_shape, (tuple, list)):
750
+ raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
751
+ f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
752
+ if not normalized_shape:
753
+ raise ValueError(
754
+ f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
755
+ f"least one element, but got normalized_shape = {normalized_shape}"
756
+ )
757
+ self.normalized_shape = normalized_shape
758
+ self.begin_norm_axis = begin_norm_axis
759
+ self.begin_params_axis = begin_params_axis
760
+ self.epsilon = epsilon
761
+ self.gamma = Parameter(initializer(
762
+ gamma_init, normalized_shape, dtype=dtype), name="gamma")
763
+ self.beta = Parameter(initializer(
764
+ beta_init, normalized_shape, dtype=dtype), name="beta")
765
+ self.layer_norm = P.LayerNorm(begin_norm_axis=self.begin_norm_axis,
766
+ begin_params_axis=self.begin_params_axis,
767
+ epsilon=self.epsilon)
768
+
769
+ def construct(self, input_x):
770
+ y, _, _ = self.layer_norm(input_x, self.gamma.astype(input_x.dtype), self.beta.astype(input_x.dtype))
771
+ return y
772
+
773
+ def extend_repr(self):
774
+ return 'normalized_shape={}, begin_norm_axis={}, begin_params_axis={}, gamma={}, beta={}'.format(
775
+ self.normalized_shape, self.begin_norm_axis, self.begin_params_axis, self.gamma, self.beta)
776
+
777
+
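As a sanity check on the formula in the docstring above, here is a NumPy-only sketch of what `LayerNorm` computes when gamma and beta keep their default initializers (ones and zeros): statistics are taken over the axes from `begin_norm_axis` to the last axis. The helper name `layer_norm_ref` is ours.

```python
import numpy as np

def layer_norm_ref(x, begin_norm_axis=-1, epsilon=1e-7):
    """Reference for y = (x - E[x]) / sqrt(Var[x] + eps) with gamma=1, beta=0."""
    axis = begin_norm_axis % x.ndim
    norm_axes = tuple(range(axis, x.ndim))
    mean = x.mean(axis=norm_axes, keepdims=True)
    var = x.var(axis=norm_axes, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon)

x = np.random.randn(20, 5, 10, 10).astype(np.float32)
y = layer_norm_ref(x, begin_norm_axis=1)
print(y.shape)                                              # (20, 5, 10, 10)
print(np.allclose(y.mean(axis=(1, 2, 3)), 0.0, atol=1e-5))  # per-sample mean is ~0
```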
778
+ class LayerNormExt(Cell):
779
+ r"""
780
+ Applies Layer Normalization over a mini-batch of inputs.
781
+
782
+ Layer Normalization is widely used in recurrent neural networks. It applies
783
+ normalization on a mini-batch of inputs for each single training case as described
784
+ in the paper `Layer Normalization <https://arxiv.org/pdf/1607.06450.pdf>`_.
785
+
786
+ Unlike Batch Normalization, Layer Normalization performs exactly the same computation at training and
787
+ testing time. It is applied across all channels and pixels of each sample, rather than across the batch.
788
+ :math:`\gamma` is the scale value learned through training and :math:`\beta` is the shift value.
789
+ It can be described using the following formula:
790
+
791
+ .. math::
792
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
793
+
794
+ .. warning::
795
+ This is an experimental API that is subject to change or deletion.
796
+
797
+ Args:
798
+ normalized_shape (Union(tuple[int], list[int], int)): The normalized shape of `x` for LayerNorm.
799
+ eps (float): A value added to the denominator for numerical stability (:math:`\epsilon`). Default: ``1e-5`` .
800
+ elementwise_affine (bool): Whether affine transformation is required. When this parameter is set to ``True``,
801
+ the weight parameter is initialized to 1 and the offset is initialized to 0. Default: ``True``.
802
+ bias (bool): If set to ``False``, the layer will not learn an additive bias (only relevant if
803
+ `elementwise_affine` is ``True``). Default: ``True``.
804
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``None`` .
805
+
806
+ Inputs:
807
+ - **x** (Tensor) - The shape is :math:`(N, *)`, where :math:`*` is equal to normalized_shape.
808
+
809
+ Outputs:
810
+ Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.
811
+
812
+ Raises:
813
+ TypeError: If `eps` is not a float.
814
+
815
+ Supported Platforms:
816
+ ``Ascend``
817
+
818
+ Examples:
819
+ >>> import mindspore as ms
820
+ >>> import numpy as np
821
+ >>> x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
822
+ >>> shape1 = x.shape[1:]
823
+ >>> m = ms.nn.LayerNormExt(shape1)
824
+ >>> output = m(x).shape
825
+ >>> print(output)
826
+ (20, 5, 10, 10)
827
+ """
828
+
829
+ def __init__(self,
830
+ normalized_shape,
831
+ eps=1e-5,
832
+ elementwise_affine=True,
833
+ bias=True,
834
+ dtype=None
835
+ ):
836
+ """Initialize LayerNormExt."""
837
+ super(LayerNormExt, self).__init__()
838
+ if isinstance(normalized_shape, numbers.Integral):
839
+ # mypy error: incompatible types in assignment
840
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
841
+ if not isinstance(normalized_shape, (tuple, list)):
842
+ raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
843
+ f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
844
+ if not normalized_shape:
845
+ raise ValueError(
846
+ f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
847
+ f"least one element, but got normalized_shape = {normalized_shape}"
848
+ )
849
+ self.normalized_shape = tuple(normalized_shape)
850
+ self.eps = eps
851
+ self.elementwise_affine = elementwise_affine
852
+ ms_dtype = mstype.float32 if dtype is None else dtype
853
+ if self.elementwise_affine:
854
+ self.weight = Parameter(Tensor(np.ones(normalized_shape), ms_dtype), name="weight")
855
+ if bias:
856
+ self.bias = Parameter(Tensor(np.zeros(normalized_shape), ms_dtype), name="bias")
857
+ else:
858
+ self.bias = None
859
+ else:
860
+ self.weight = None
861
+ self.bias = None
862
+
863
+ def construct(self, input):
864
+ y = ops.layer_norm(input, self.normalized_shape, self.weight,
865
+ self.bias, self.eps)
866
+ return y
867
+
868
+ def extend_repr(self):
869
+ return 'normalized_shape={}, eps={}, elementwise_affine={}'.format(
870
+ self.normalized_shape, self.eps, self.elementwise_affine)
871
+
872
+
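Since `construct` above simply forwards to the functional `ops.layer_norm`, the cell and the functional call produce the same result when given the cell's own parameters. A hedged usage sketch, assuming an environment where this experimental (Ascend-only) interface is available:

```python
import numpy as np
import mindspore as ms
from mindspore import ops

x = ms.Tensor(np.ones([20, 5, 10, 10]), ms.float32)
net = ms.nn.LayerNormExt((5, 10, 10))

y_cell = net(x)
# Same computation expressed functionally, reusing the cell's parameters.
y_func = ops.layer_norm(x, net.normalized_shape, net.weight, net.bias, net.eps)
print(y_cell.shape, y_func.shape)  # (20, 5, 10, 10) (20, 5, 10, 10)
```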
873
+ class _InstanceNorm(Cell):
874
+ """Instance Normalization base class."""
875
+ @cell_attr_register
876
+ def __init__(self,
877
+ num_features,
878
+ eps=1e-5,
879
+ momentum=0.1,
880
+ affine=True,
881
+ gamma_init='ones',
882
+ beta_init='zeros',
883
+ dtype=mstype.float32):
884
+ """Initialize Normalization base class."""
885
+ super(_InstanceNorm, self).__init__()
886
+ validator.check_value_type('num_features', num_features, [int], self.cls_name)
887
+ validator.check_value_type('eps', eps, [float], self.cls_name)
888
+ validator.check_value_type('momentum', momentum, [float], self.cls_name)
889
+ validator.check_value_type('affine', affine, [bool], self.cls_name)
890
+ args_input = {"gamma_init": gamma_init, "beta_init": beta_init}
891
+ self.check_types_valid(args_input, 'InstanceNorm2d')
892
+ if num_features < 1:
893
+ raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")
894
+
895
+ if momentum < 0 or momentum > 1:
896
+ raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
897
+ f"but got {momentum}.")
898
+ self.num_features = num_features
899
+ self.eps = eps
900
+ self.moving_mean = Parameter(initializer('zeros', num_features, dtype=dtype), name="mean", requires_grad=False)
901
+ self.moving_variance = Parameter(initializer('ones', num_features, dtype=dtype), name="variance",
902
+ requires_grad=False)
903
+ self.gamma = Parameter(initializer(
904
+ gamma_init, num_features, dtype=dtype), name="gamma", requires_grad=affine)
905
+ self.beta = Parameter(initializer(
906
+ beta_init, num_features, dtype=dtype), name="beta", requires_grad=affine)
907
+
908
+ self.shape = P.Shape()
909
+ self.momentum = momentum
910
+ self.instance_bn = P.InstanceNorm(epsilon=self.eps, momentum=self.momentum)
911
+
912
+ def construct(self, x):
913
+ self._check_input_dim(self.shape(x), self.cls_name)
914
+ return self.instance_bn(x,
915
+ self.gamma,
916
+ self.beta,
917
+ self.moving_mean,
918
+ self.moving_variance)[0]
919
+
920
+ def extend_repr(self):
921
+ return 'num_features={}, eps={}, momentum={}, gamma={}, beta={}, moving_mean={}, moving_variance={}'.format(
922
+ self.num_features, self.eps, self.momentum, self.gamma, self.beta, self.moving_mean, self.moving_variance)
923
+
924
+ def check_types_valid(self, args_dict, name):
925
+ for key, _ in args_dict.items():
926
+ val = args_dict[key]
927
+ if not isinstance(val, (Tensor, numbers.Number, str, Initializer)):
928
+ raise TypeError(f"For '{self.cls_name}', the type of '{key}' must be in "
929
+ f"[Tensor, numbers.Number, str, Initializer], but got type {type(val).__name__}.")
930
+ if isinstance(val, Tensor) and val.dtype != mstype.float32:
931
+ raise TypeError(f"For '{self.cls_name}', the type of '{key}' must be float32, "
932
+ f"but got {val.dtype}.")
933
+
934
+
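`moving_mean` and `moving_variance` above are refreshed with the momentum rule quoted in the InstanceNorm docstrings below, :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`. A one-line numeric illustration (the values are made up):

```python
momentum = 0.1
running_mean = 0.5  # \hat{x}: the current estimated statistic
batch_mean = 2.0    # x_t: the statistic observed on the current batch
new_running_mean = (1 - momentum) * batch_mean + momentum * running_mean
print(new_running_mean)  # 1.85
```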
935
+ class InstanceNorm1d(_InstanceNorm):
936
+ r"""
937
+ This layer applies Instance Normalization over a 3D input (a mini-batch of 1D inputs with
938
+ additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
939
+ Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
940
+ of data and the learned parameters which can be described in the following formula.
941
+
942
+ .. math::
943
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
944
+
945
+ The size of :math:`\gamma` and :math:`\beta`, the learnable parameter vectors, is num_features if affine is True.
946
+ The standard-deviation is calculated via the biased estimator.
947
+
948
+ This layer uses instance statistics computed from input data in both training and evaluation modes.
949
+
950
+ InstanceNorm1d and BatchNorm1d are very similar, but have some differences. InstanceNorm1d is applied on each
951
+ channel of channeled data like RGB images, but BatchNorm1d is usually applied on each batch of batched data.
952
+
953
+ Note:
954
+ The formula for updating the running_mean and running_var is
955
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
956
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
957
+
958
+ Args:
959
+ num_features (int): `C` from an expected input of size :math:`(N, C, L)`.
960
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
961
+ momentum (float): A floating hyperparameter of the momentum for the
962
+ running_mean and running_var computation. Default: ``0.1`` .
963
+ affine (bool): A bool value. When set to True, gamma and beta can be learned. Default: ``True`` .
964
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
965
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
966
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
967
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
968
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
969
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
970
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
971
+
972
+ Inputs:
973
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, L)`. Data type: float16 or float32.
974
+
975
+ Outputs:
976
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, L)`. Same type and
977
+ shape as the `x`.
978
+
979
+ Raises:
980
+ TypeError: If the type of `num_features` is not int.
981
+ TypeError: If the type of `eps` is not float.
982
+ TypeError: If the type of `momentum` is not float.
983
+ TypeError: If the type of `affine` is not bool.
984
+ TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
985
+ float32.
986
+ ValueError: If `num_features` is less than 1.
987
+ ValueError: If `momentum` is not in range [0, 1].
988
+ ValueError: If the shape of `gamma_init` / `beta_init` is not :math:`(C)`.
989
+ KeyError: If any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does
990
+ not exist.
991
+
992
+ Supported Platforms:
993
+ ``GPU``
994
+
995
+ Examples:
996
+ >>> import mindspore as ms
997
+ >>> import numpy as np
998
+ >>> net = ms.nn.InstanceNorm1d(3)
999
+ >>> x = ms.Tensor(np.ones([2, 3, 5]), ms.float32)
1000
+ >>> output = net(x)
1001
+ >>> print(output.shape)
1002
+ (2, 3, 5)
1003
+ """
1004
+
1005
+ @staticmethod
1006
+ @_primexpr
1007
+ def _check_input_dim(shape, cls_name):
1008
+ dim = len(shape)
1009
+ _check_dim(dim, 3, cls_name)
1010
+
1011
+
1012
+
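For intuition, a NumPy sketch of the statistics `InstanceNorm1d` normalizes with: mean and (biased) variance per sample and per channel, i.e. over the length axis only, with gamma=1 and beta=0. The helper name `instance_norm1d_ref` is ours.

```python
import numpy as np

def instance_norm1d_ref(x, eps=1e-5):
    """Normalize each (sample, channel) slice of an (N, C, L) array independently."""
    mean = x.mean(axis=2, keepdims=True)  # statistics over L only
    var = x.var(axis=2, keepdims=True)    # biased estimator, as stated in the docstring
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(2, 3, 5).astype(np.float32)
y = instance_norm1d_ref(x)
print(y.shape)                                       # (2, 3, 5)
print(np.allclose(y.mean(axis=2), 0.0, atol=1e-5))   # each (N, C) slice has mean ~0
```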
1013
+ class InstanceNorm2d(_InstanceNorm):
1014
+ r"""
1015
+ This layer applies Instance Normalization over a 4D input (a mini-batch of 2D inputs with
1016
+ additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
1017
+ Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
1018
+ of data and the learned parameters which can be described in the following formula.
1019
+
1020
+ .. math::
1021
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
1022
+
1023
+ :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size num_features if affine is True.
1024
+ The standard-deviation is calculated via the biased estimator.
1025
+
1026
+ This layer uses instance statistics computed from input data in both training and evaluation modes.
1027
+
1028
+ InstanceNorm2d and BatchNorm2d are very similar, but have some differences. InstanceNorm2d is applied on each
1029
+ channel of channeled data like RGB images, but BatchNorm2d is usually applied on each batch of batched data.
1030
+
1031
+ Note:
1032
+ The formula for updating the running_mean and running_var is
1033
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
1034
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
1035
+
1036
+ Args:
1037
+ num_features (int): `C` from an expected input of size :math:`(N, C, H, W)`.
1038
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
1039
+ momentum (float): A floating hyperparameter of the momentum for the
1040
+ running_mean and running_var computation. Default: ``0.1`` .
1041
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
1042
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
1043
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1044
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
1045
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
1046
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1047
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
1048
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
1049
+
1050
+ Inputs:
1051
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, H, W)`. Data type: float16 or float32.
1052
+
1053
+ Outputs:
1054
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, H, W)`. Same type and
1055
+ shape as the `x`.
1056
+
1057
+ Raises:
1058
+ TypeError: If the type of `num_features` is not int.
1059
+ TypeError: If the type of `eps` is not float.
1060
+ TypeError: If the type of `momentum` is not float.
1061
+ TypeError: If the type of `affine` is not bool.
1062
+ TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
1063
+ float32.
1064
+ ValueError: If `num_features` is less than 1.
1065
+ ValueError: If `momentum` is not in range [0, 1].
1066
+ ValueError: If the shape of `gamma_init` / `beta_init` is not :math:`(C)`.
1067
+ KeyError: If any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does
1068
+ not exist.
1069
+
1070
+ Supported Platforms:
1071
+ ``GPU``
1072
+
1073
+ Examples:
1074
+ >>> import mindspore as ms
1075
+ >>> import numpy as np
1076
+ >>> net = ms.nn.InstanceNorm2d(3)
1077
+ >>> x = ms.Tensor(np.ones([2, 3, 2, 2]), ms.float32)
1078
+ >>> output = net(x)
1079
+ >>> print(output.shape)
1080
+ (2, 3, 2, 2)
1081
+ """
1082
+
1083
+ @staticmethod
1084
+ @_primexpr
1085
+ def _check_input_dim(shape, cls_name):
1086
+ dim = len(shape)
1087
+ _check_dim(dim, 4, cls_name)
1088
+
1089
+
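Because instance statistics come from the input itself in both training and evaluation modes (as the docstring notes), toggling `set_train` should not change the forward result for the same input. A hedged usage sketch; per the docstring this cell is GPU-only in this release:

```python
import numpy as np
import mindspore as ms

net = ms.nn.InstanceNorm2d(3)
x = ms.Tensor(np.random.randn(2, 3, 4, 4), ms.float32)

y_train = net(x)      # cells start in training mode
net.set_train(False)  # switch to evaluation mode
y_eval = net(x)
# Normalization uses the input's own statistics in both modes,
# so the two outputs should agree up to numerical noise.
print(np.allclose(y_train.asnumpy(), y_eval.asnumpy(), atol=1e-5))
```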
1090
+ class InstanceNorm3d(_InstanceNorm):
1091
+ r"""
1092
+ This layer applies Instance Normalization over a 5D input (a mini-batch of 3D inputs with
1093
+ additional channel dimension). Refer to the paper `Instance Normalization: The Missing Ingredient for
1094
+ Fast Stylization <https://arxiv.org/abs/1607.08022>`_. It rescales and recenters the feature using a mini-batch
1095
+ of data and the learned parameters which can be described in the following formula.
1096
+
1097
+ .. math::
1098
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
1099
+
1100
+ :math:`\gamma` and :math:`\beta` are learnable parameter vectors of size num_features if affine is True.
1101
+ The standard-deviation is calculated via the biased estimator.
1102
+
1103
+ This layer uses instance statistics computed from input data in both training and evaluation modes.
1104
+
1105
+ InstanceNorm3d and BatchNorm3d are very similar, but have some differences. InstanceNorm3d is applied on each
1106
+ channel of channeled data like RGB images, but BatchNorm3d is usually applied on each batch of batched data.
1107
+
1108
+ Note:
1109
+ The formula for updating the running_mean and running_var is
1110
+ :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times x_t + \text{momentum} \times \hat{x}`,
1111
+ where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the new observed value.
1112
+
1113
+ Args:
1114
+ num_features (int): `C` from an expected input of size :math:`(N, C, D, H, W)`.
1115
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5`` .
1116
+ momentum (float): A floating hyperparameter of the momentum for the
1117
+ running_mean and running_var computation. Default: ``0.1`` .
1118
+ affine (bool): A bool value. When set to ``True`` , gamma and beta can be learned. Default: ``True`` .
1119
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
1120
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1121
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'ones'`` .
1122
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
1123
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` , etc.
1124
+ When initialized with Tensor, the shape should be :math:`(C)`. Default: ``'zeros'`` .
1125
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
1126
+
1127
+ Inputs:
1128
+ - **x** (Tensor) - Tensor of shape :math:`(N, C, D, H, W)`. Data type: float16 or float32.
1129
+
1130
+ Outputs:
1131
+ Tensor, the normalized, scaled, offset tensor, of shape :math:`(N, C, D, H, W)`. Same type and
1132
+ shape as the `x`.
1133
+
1134
+ Raises:
1135
+ TypeError: If the type of `num_features` is not int.
1136
+ TypeError: If the type of `eps` is not float.
1137
+ TypeError: If the type of `momentum` is not float.
1138
+ TypeError: If the type of `affine` is not bool.
1139
+ TypeError: If the type of `gamma_init`/`beta_init` is not same, or if the initialized element type is not
1140
+ float32.
1141
+ ValueError: If `num_features` is less than 1.
1142
+ ValueError: If `momentum` is not in range [0, 1].
1143
+ ValueError: If the shape of `gamma_init` / `beta_init` is not :math:`(C)`.
1144
+ KeyError: If any of `gamma_init`/`beta_init` is a str and the homonymous class inheriting from `Initializer` does
1145
+ not exist.
1146
+
1147
+ Supported Platforms:
1148
+ ``GPU``
1149
+
1150
+ Examples:
1151
+ >>> import mindspore as ms
1152
+ >>> import numpy as np
1153
+ >>> net = ms.nn.InstanceNorm3d(3)
1154
+ >>> x = ms.Tensor(np.ones([2, 3, 5, 2, 2]), ms.float32)
1155
+ >>> output = net(x)
1156
+ >>> print(output.shape)
1157
+ (2, 3, 5, 2, 2)
1158
+ """
1159
+
1160
+ @staticmethod
1161
+ @_primexpr
1162
+ def _check_input_dim(shape, cls_name):
1163
+ dim = len(shape)
1164
+ _check_dim(dim, 5, cls_name)
1165
+
1166
+
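The same per-(sample, channel) statistics apply in the 3D case; the earlier 1D sketch generalizes by normalizing over every axis after the channel axis, which covers the :math:`(N, C, L)`, :math:`(N, C, H, W)` and :math:`(N, C, D, H, W)` layouts uniformly (gamma=1, beta=0 assumed; the helper name is ours):

```python
import numpy as np

def instance_norm_ref(x, eps=1e-5):
    """Normalize over all axes after N and C, for 1D/2D/3D instance norm alike."""
    axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

x = np.random.randn(2, 3, 5, 2, 2).astype(np.float32)
print(instance_norm_ref(x).shape)  # (2, 3, 5, 2, 2)
```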
1167
+ class GroupNorm(Cell):
1168
+ r"""
1169
+ Group Normalization over a mini-batch of inputs.
1170
+
1171
+ Group Normalization is widely used in recurrent neural networks. It applies
1172
+ normalization on a mini-batch of inputs for each single training case as described
1173
+ in the paper `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_. Group Normalization
1174
+ divides the channels into groups and computes within each group the mean and variance for normalization,
1175
+ and it performs very stable over a wide range of batch size. :math:`\gamma` and :math:`\beta` are trainable scale
1176
+ and shift.
1177
+ It can be described using the following formula:
1178
+
1179
+ .. math::
1180
+ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
1181
+
1182
+ Args:
1183
+ num_groups (int): The number of groups to be divided along the channel dimension.
1184
+ num_channels (int): The number of input channels.
1185
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-05`` .
1186
+ affine (bool): A bool value, this layer will have learnable affine parameters when set to ``true`` .
1187
+ Default: ``True`` .
1188
+ gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
1189
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
1190
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'ones'`` . If gamma_init is a Tensor, the shape
1191
+ must be :math:`(num\_channels)`.
1192
+ beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
1193
+ The values of str refer to the function `initializer` including ``'zeros'`` , ``'ones'`` ,
1194
+ ``'xavier_uniform'`` , ``'he_uniform'`` , etc. Default: ``'zeros'`` . If beta_init is a Tensor, the shape
1195
+ must be :math:`(num\_channels)`.
1196
+ dtype (:class:`mindspore.dtype`): Dtype of Parameters. Default: ``mstype.float32`` .
1197
+
1198
+ Inputs:
1199
+ - **x** (Tensor) - The input feature with shape :math:`(N, C, *)`, where :math:`*` means any number of
1200
+ additional dimensions.
1201
+
1202
+ Outputs:
1203
+ Tensor, the normalized and scaled offset tensor, has the same shape and data type as the `x`.
1204
+
1205
+ Raises:
1206
+ TypeError: If `num_groups` or `num_channels` is not an int.
1207
+ TypeError: If `eps` is not a float.
1208
+ TypeError: If `affine` is not a bool.
1209
+ ValueError: If `num_groups` or `num_channels` is less than 1.
1210
+ ValueError: If `num_channels` is not divisible by `num_groups`.
1211
+
1212
+ Supported Platforms:
1213
+ ``Ascend`` ``GPU`` ``CPU``
1214
+
1215
+ Examples:
1216
+ >>> import mindspore as ms
1217
+ >>> import numpy as np
1218
+ >>> group_norm_op = ms.nn.GroupNorm(2, 2)
1219
+ >>> x = ms.Tensor(np.ones([1, 2, 4, 4], np.float32))
1220
+ >>> output = group_norm_op(x)
1221
+ >>> print(output)
1222
+ [[[[0. 0. 0. 0.]
1223
+ [0. 0. 0. 0.]
1224
+ [0. 0. 0. 0.]
1225
+ [0. 0. 0. 0.]]
1226
+ [[0. 0. 0. 0.]
1227
+ [0. 0. 0. 0.]
1228
+ [0. 0. 0. 0.]
1229
+ [0. 0. 0. 0.]]]]
1230
+ """
1231
+
1232
+ def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, gamma_init='ones', beta_init='zeros',
1233
+ dtype=mstype.float32):
1234
+ """Initialize GroupNorm."""
1235
+ super(GroupNorm, self).__init__()
1236
+ self.num_groups = validator.check_positive_int(num_groups, "num_groups", self.cls_name)
1237
+ self.num_channels = validator.check_positive_int(num_channels, "num_channels", self.cls_name)
1238
+ if num_channels % num_groups != 0:
1239
+ raise ValueError(f"For '{self.cls_name}', the 'num_channels' must be divided by 'num_groups', "
1240
+ f"but got 'num_channels': {num_channels}, 'num_groups': {num_groups}.")
1241
+ self.eps = validator.check_value_type('eps', eps, (float,), type(self).__name__)
1242
+ self.affine = validator.check_bool(affine, arg_name="affine", prim_name=self.cls_name)
1243
+
1244
+ self.gamma = Parameter(initializer(
1245
+ gamma_init, self.num_channels, dtype=dtype), name="gamma", requires_grad=affine)
1246
+ self.beta = Parameter(initializer(
1247
+ beta_init, self.num_channels, dtype=dtype), name="beta", requires_grad=affine)
1248
+
1249
+ def _cal_output(self, x):
1250
+ """calculate groupnorm output"""
1251
+ return group_norm(x, self.num_groups, self.gamma.to(x.dtype), self.beta.to(x.dtype), self.eps)
1252
+
1253
+ @staticmethod
1254
+ @_primexpr
1255
+ def _channel_check(channel, num_channel, prim_name=None):
1256
+ def _check():
1257
+ msg_prefix = f"For '{prim_name}', the" if prim_name else "The"
1258
+ if channel != num_channel:
1259
+ raise ValueError(f"{msg_prefix} channel(the second dim of the input 'x') must be equal to "
1260
+ f"num_channels, but got channel: {channel}, num_channels: {num_channel}.")
1261
+ _check()
1262
+
1263
+ @staticmethod
1264
+ @constexpr
1265
+ def _check_dtype(dtype, valid_dtypes, prim_name=None):
1266
+ validator.check_type_name("input", dtype, valid_dtypes, prim_name)
1267
+
1268
+ def extend_repr(self):
1269
+ return 'num_groups={}, num_channels={}'.format(self.num_groups, self.num_channels)
1270
+
1271
+ def construct(self, x):
1272
+ output = self._cal_output(x)
1273
+ return output
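`_cal_output` above defers to the functional `group_norm`; conceptually, each sample's channels are split into `num_groups` groups and each group is normalized with its own mean and biased variance. A NumPy sketch of that computation (gamma=1, beta=0 assumed; the helper name `group_norm_ref` is ours):

```python
import numpy as np

def group_norm_ref(x, num_groups, eps=1e-5):
    """Normalize an (N, C, *spatial) array over groups of C // num_groups channels."""
    n, c = x.shape[:2]
    g = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
    axes = tuple(range(2, g.ndim))  # statistics within each group of each sample
    mean = g.mean(axis=axes, keepdims=True)
    var = g.var(axis=axes, keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)

x = np.ones((1, 2, 4, 4), dtype=np.float32)
print(group_norm_ref(x, num_groups=2))  # all zeros, matching the docstring example above
```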