mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387) hide show
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1419 @@
1
+ # Copyright 2020-2021 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+ """Check parameters."""
16
+ from __future__ import absolute_import
17
+
18
+ import re
19
+ import inspect
20
+ import math
21
+ from types import FunctionType, MethodType
22
+ from functools import reduce, wraps
23
+ from itertools import repeat
24
+ from collections.abc import Iterable
25
+ import numpy as np
26
+
27
+ from mindspore import context
28
+ from mindspore import log as logger
29
+ from mindspore.common import dtype as mstype
30
+ from mindspore._c_expression import Tensor as Tensor_
31
+
32
# Relation codes for binary comparisons (evaluated by _check_binary_rel,
# rendered by _format_str_one_value).
EQ, NE, LT, LE, GT, GE = 1, 2, 3, 4, 5, 6  # ==, !=, <, <=, >, >=
# Relation codes for scalar range checks (evaluated by _check_inc_rel,
# rendered by _format_str_two_value):
#   INC_NEITHER -> (),  INC_LEFT -> [),  INC_RIGHT -> (],  INC_BOTH -> []
INC_NEITHER, INC_LEFT, INC_RIGHT, INC_BOTH = 7, 8, 9, 10
# Relation codes for collection membership: value in / not in a collection.
IN, NOT_IN = 11, 12
46
+
47
+
48
def _check_binary_rel(val1, val2, rel):
    """
    Evaluate the binary relation ``val1 <rel> val2``.

    Args:
        val1: Left-hand operand.
        val2: Right-hand operand (the container being searched for IN / NOT_IN).
        rel (int): One of EQ, NE, LT, LE, GT, GE, IN, NOT_IN.

    Returns:
        The raw result of the comparison, or False when `rel` is not a known code.
    """
    relations = (
        (EQ, lambda lhs, rhs: lhs == rhs),
        (NE, lambda lhs, rhs: lhs != rhs),
        (LT, lambda lhs, rhs: lhs < rhs),
        (LE, lambda lhs, rhs: lhs <= rhs),
        (GT, lambda lhs, rhs: lhs > rhs),
        (GE, lambda lhs, rhs: lhs >= rhs),
        (IN, lambda lhs, rhs: lhs in rhs),
        (NOT_IN, lambda lhs, rhs: lhs not in rhs),
    )
    # Codes are tested in order with ``==``, exactly like an if-chain would.
    for code, compare in relations:
        if rel == code:
            return compare(val1, val2)
    return False
68
+
69
+
70
def _check_inc_rel(val, lower, upper, rel):
    """
    Check whether `val` lies inside the interval bounded by `lower` and `upper`.

    Args:
        val: Value to test.
        lower: Lower bound.
        upper: Upper bound.
        rel (int): Which bounds are inclusive: INC_NEITHER ``(lower, upper)``,
            INC_LEFT ``[lower, upper)``, INC_RIGHT ``(lower, upper]`` or INC_BOTH ``[lower, upper]``.

    Returns:
        bool: True if `val` is inside the interval. False if it is outside, if it does not
        compare as ordered with the bounds (e.g. float NaN), or if `rel` is not one of
        the four codes above.
    """
    if rel == INC_NEITHER:
        inside = lower < val < upper
    elif rel == INC_LEFT:
        inside = lower <= val < upper
    elif rel == INC_RIGHT:
        inside = lower < val <= upper
    elif rel == INC_BOTH:
        inside = lower <= val <= upper
    else:
        return False
    # Positive comparisons are used instead of negating an "out of range" test:
    # every comparison against NaN is False, so NaN is now rejected rather than
    # slipping through as "not below and not above".
    return bool(inside)
82
+
83
+
84
def _format_str_one_value(value, rel):
    """
    Render the human-readable form of ``<rel> value`` for error messages.

    Args:
        value: The reference value (any object; formatted with an f-string).
        rel (int): One of EQ, NE, LT, LE, GT, GE, IN, NOT_IN.

    Returns:
        str: e.g. ``"> 0"`` or ``"not in [1, 2]"``; an empty string for an unknown `rel`.
    """
    symbols = (
        (EQ, "="),
        (NE, "!="),
        (LT, "<"),
        (LE, "<="),
        (GT, ">"),
        (GE, ">="),
        (IN, "in"),
        (NOT_IN, "not in"),
    )
    for code, symbol in symbols:
        if rel == code:
            return f"{symbol} {value}"
    return ""
104
+
105
+
106
def _format_str_two_value(val1, val2, rel):
    """
    Render an interval in bracket notation for error messages.

    Args:
        val1: Lower bound.
        val2: Upper bound.
        rel (int): One of INC_NEITHER, INC_LEFT, INC_RIGHT, INC_BOTH.

    Returns:
        str: e.g. ``"[0, 1)"``; an empty string for an unknown `rel`.
    """
    brackets = (
        (INC_NEITHER, "(", ")"),
        (INC_LEFT, "[", ")"),
        (INC_RIGHT, "(", "]"),
        (INC_BOTH, "[", "]"),
    )
    for code, opening, closing in brackets:
        if rel == code:
            return f"{opening}{val1}, {val2}{closing}"
    return ""
118
+
119
+
120
def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret_five=False,
                           greater_zero=True, third_one=False, three_input=False):
    """
    Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.

    Args:
        arg_name (str): Parameter name used in error messages.
        arg_value (Union[int, tuple]): Value to check. An int is repeated three times.
        prim_name (str): Primitive name used in error messages.
        allow_five (bool): Also accept a 5-element tuple. Default: False.
        ret_five (bool): Return 5 elements instead of 3. A 3-value input is prefixed
            with ``(1, 1)``. Default: False.
        greater_zero (bool): If True every element must be > 0, otherwise >= 0. Default: True.
        third_one (bool): Additionally require ``result[-3] == 1``. Default: False.
        three_input (bool): Reject tuples whose length is not exactly 3. Default: False.

    Returns:
        tuple: The normalized tuple. When `ret_five` is False and a 5-element tuple
        was given, only its last three elements are returned.

    Raises:
        ValueError: If the tuple length, an element value, or the `third_one`
            constraint is violated. Type errors come from ``check_value_type``
            (defined elsewhere in this module).
    """

    def _raise_message(third_one_flag=False, three_input_flag=False):
        # Always raises. `ret_value` here is the enclosing function's local; the
        # third_one_flag path is only reached after it has been assigned below.
        if third_one_flag:
            raise ValueError(f"For '{prim_name}', the depth of parameter '{arg_name}' must be 1, " \
                             f"but got {ret_value[-3]}.")
        if three_input_flag:
            raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer " \
                             f"or a tuple of three positive integer, but got {arg_value}.")
        # NOTE(review): this message says "positive" even when greater_zero is False
        # (where 0 is allowed); left unchanged here.
        raise ValueError(f"For '{prim_name}', the parameter '{arg_name}' must be an positive integer or " \
                         f"a tuple of three {'or five ' if allow_five else ''}positive integer, but got {arg_value}")

    def _get_return_value():
        # Validate the tuple length, then normalize to a 3- or 5-element tuple.
        def _check():
            if not isinstance(arg_value, int):
                if len(arg_value) == 5:
                    if not allow_five:
                        _raise_message()
                elif not len(arg_value) == 3:
                    _raise_message()

        _check()
        if isinstance(arg_value, int):
            ret = (1, 1, arg_value, arg_value, arg_value) if ret_five else (arg_value, arg_value, arg_value)
        elif len(arg_value) == 3:
            ret = (1, 1, arg_value[0], arg_value[1], arg_value[2]) if ret_five else arg_value
        else:  # case: len(arg_value) == 5
            ret = arg_value if ret_five else (arg_value[2], arg_value[3], arg_value[4])

        return ret

    def _check_value(ret_value):
        # Every element must be a non-bool int and > 0 (or >= 0 when greater_zero is
        # False); any other element falls through to the raise.
        for item in ret_value:
            if isinstance(item, int) and not isinstance(item, bool):
                if greater_zero and item > 0:
                    continue
                if not greater_zero and item >= 0:
                    continue
            _raise_message()

    def _check_third_one(ret_value):
        # ret_value[-3] is the first of the trailing three values in both the 3- and 5-element forms.
        if third_one:
            if ret_value[-3] != 1:
                _raise_message(third_one_flag=third_one)

    check_value_type(arg_name, arg_value, (int, tuple), prim_name)
    if three_input and isinstance(arg_value, tuple):
        if len(arg_value) != 3:
            _raise_message(three_input_flag=three_input)
    ret_value = _get_return_value()
    _check_value(ret_value)
    _check_third_one(ret_value)

    return tuple(ret_value)
178
+
179
+
180
+ def _check_dup(axes):
181
+ for item in axes:
182
+ count = 0
183
+ for item2 in axes:
184
+ if item == item2:
185
+ count += 1
186
+
187
+ if count > 1:
188
+ raise ValueError(f"The element of parameter 'axis' can not be duplicate, but got {axes}.")
189
+
190
+
191
def _check_number(arg_value, value, rel, arg_type=int, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an `arg_type` number and that ``arg_value <rel> value`` holds.

    Args:
        arg_value: The value to check.
        value: The reference value `arg_value` is compared against.
        rel (int): Relation code understood by `_check_binary_rel` (EQ, NE, LT, LE, GT, GE, ...).
        arg_type (type): Required type of `arg_value`. bool values are always rejected. Default: int.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` is not an instance of `arg_type`, or is a bool.
        ValueError: If a non-integer `arg_value` is inf or nan, or if the relation does not hold.

    Usage:
    - arg_value = _check_number(arg_value, 2, GT, int, "value", None)
    """
    prim_name = f"For \'{prim_name}\', the " if prim_name else 'The '
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'
    prim_info = f'{prim_name}' + f'{arg_name}'

    if not isinstance(arg_value, arg_type):
        raise TypeError(f"{prim_info} must be {arg_type.__name__}, but got '{type(arg_value).__name__}'")
    # Integers (bool included) can never be inf/nan. Skipping them also avoids the
    # OverflowError math.isinf raises when an int is too large to convert to float.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)):
        raise ValueError(f"{prim_info} must be a legal value, but got '{arg_value}'.")

    # After the isinstance check above, the only remaining type problem is a bool
    # (an int subclass) passed where a number is expected.
    type_mismatch = isinstance(arg_value, bool)
    rel_ret = _check_binary_rel(arg_value, value, rel)
    if type_mismatch or not rel_ret:
        rel_str = _format_str_one_value(value, rel)
        msg = f"{prim_info} must be {arg_type.__name__} and must {rel_str}, " \
              f"but got '{arg_value}' with type '{type(arg_value).__name__}'."
        if type_mismatch:
            raise TypeError(msg)
        raise ValueError(msg)
    return arg_value
221
+
222
+
223
def check_is_number(arg_value, arg_type, arg_name=None, prim_name=None):
    """
    Checks that `arg_value` is a finite instance of `arg_type` (bool is rejected).

    Args:
        arg_value: Value to check.
        arg_type (type): Required type, e.g. int or float.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` is not an instance of `arg_type`, or is a bool.
        ValueError: If `arg_value` is inf or nan.

    Usage:
    - number = check_is_number(number, int)
    - number = check_is_number(number, int, "bias")
    - number = check_is_number(number, int, "bias", "bias_class")
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    if not isinstance(arg_value, arg_type) or isinstance(arg_value, bool):
        raise TypeError(f"{prim_name} type of {arg_name} must be '{arg_type.__name__}', "
                        f"but got '{type(arg_value).__name__}'.")
    # Integers are always finite. Skipping them avoids the OverflowError math.isinf
    # raises when an int is too large to convert to float.
    if not isinstance(arg_value, int) and (math.isinf(arg_value) or math.isnan(arg_value)):
        raise ValueError(f"{prim_name} {arg_name} must be a legal float, but got '{arg_value}'.")
    return arg_value
245
+
246
+
247
def check_number_range(arg_value, lower_limit, upper_limit, rel, value_type, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a number of type `value_type` lying inside a range.

    Args:
        arg_value: Value to check. NumPy arrays and NumPy scalars are accepted in addition to `value_type`.
        lower_limit: Lower bound of the range.
        upper_limit: Upper bound of the range.
        rel (int): Which bounds are inclusive: INC_NEITHER, INC_LEFT, INC_RIGHT or INC_BOTH.
        value_type (type): Expected type of `arg_value`, e.g. int or float. bool is always rejected.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` has the wrong type or is a bool.
        ValueError: If `arg_value` is outside the range.

    Usage:
    - number = check_number_range(number, 0.0, 1.0, INC_NEITHER, float, "number")  # number in (0.0, 1.0)
    - number = check_number_range(number, 0, 1, INC_BOTH, int, "number")  # number in [0, 1]
    """
    prim_name = f"For \'{prim_name}\', the" if prim_name else 'The'
    arg_name = f"\'{arg_name}\'" if arg_name else 'input value'

    type_mismatch = not isinstance(arg_value, (np.ndarray, np.generic, value_type)) or isinstance(arg_value, bool)
    if type_mismatch:
        raise TypeError(f"{prim_name} {arg_name} must be '{value_type.__name__}', "
                        f"but got '{type(arg_value).__name__}'.")

    if not _check_inc_rel(arg_value, lower_limit, upper_limit, rel):
        rel_str = _format_str_two_value(lower_limit, upper_limit, rel)
        raise ValueError(f"{prim_name} {arg_name} must be in range of {rel_str}, "
                         f"but got {arg_value} with type '{type(arg_value).__name__}'.")
    return arg_value
271
+
272
+
273
def check(arg_name, arg_value, value_name, value, rel=EQ, prim_name=None, excp_cls=ValueError):
    """
    Method for judging relation between two int values or list/tuple made up of ints.
    This method is not suitable for judging relation between floats, since it does not consider float error.

    Args:
        arg_name (str): Name of the checked argument. It is quoted in the message unless it contains a space.
        arg_value: Value being checked.
        value_name (str): Name of the reference value, shown in the message.
        value: Reference value.
        rel (int): Relation code. Default: EQ.
        prim_name (str): Primitive name used in the message. Default: None.
        excp_cls (type): Exception class raised on failure. Default: ValueError.

    Returns:
        The unchanged `arg_value`.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value
    rel_str = _format_str_one_value(f'{value_name}: {value}', rel)
    prefix = f'For \'{prim_name}\', the' if prim_name else "The"
    subject = f"{prefix} {arg_name}" if " " in arg_name else f"{prefix} \'{arg_name}\'"
    raise excp_cls(f'{subject} should be {rel_str}, but got {arg_value}.')
288
+
289
+
290
def check_int(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Check that integer `arg_value` satisfies ``arg_value <rel> value``.

    Delegates to `_check_number` with ``arg_type=int``, which rejects bool values.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_int(number, 0, GE, "number", None)  # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=int, arg_name=arg_name, prim_name=prim_name)
298
+
299
+
300
def check_is_int(arg_value, arg_name=None, prim_name=None):
    """
    Checks that the input value is an int (bool is rejected) and returns it.

    Delegates to `check_is_number` with ``arg_type=int``.

    Args:
        arg_value: Value to check.
        arg_name (str): Argument name used in error messages. Default: None.
        prim_name (str): Primitive name used in error messages. Default: None.

    Returns:
        The unchanged `arg_value`.

    Raises:
        TypeError: If `arg_value` is not an int, or is a bool.

    Usage:
    - number = check_is_int(number)
    - number = check_is_int(number, "bias")
    - number = check_is_int(number, "bias", "bias_class")
    """
    return check_is_number(arg_value, int, arg_name, prim_name)
310
+
311
+
312
def check_equal_int(arg_value, value, arg_name=None, prim_name=None):
    """
    Check that integer `arg_value` equals `value`.

    Delegates to `_check_number` with the EQ relation and ``arg_type=int``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_equal_int(number, 0, "number", None)  # number == 0
    """
    return _check_number(arg_value, value, EQ, arg_type=int, arg_name=arg_name, prim_name=prim_name)
320
+
321
+
322
def check_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a positive integer (``arg_value > 0``).

    Delegates to `_check_number` with the GT relation against 0 and ``arg_type=int``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_positive_int(number)
    - number = check_positive_int(number, "bias")
    """
    return _check_number(arg_value, 0, GT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
331
+
332
+
333
def check_positive_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a positive integer (> 0).

    Each element is validated with `_check_number`. Error messages name the offending
    element as ``<arg_name>[<index>]``, or ``arg_name[<index>]`` when `arg_name` is not given.

    Returns:
        The unchanged `sequence`.

    Usage:
    - sequence = check_positive_int_sequence(sequence)
    - sequence = check_positive_int_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GT, int, f"{base_name}[{idx}]", prim_name)
    return sequence
346
+
347
+
348
def check_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a negative integer (``arg_value < 0``).

    Delegates to `_check_number` with the LT relation against 0 and ``arg_type=int``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_negative_int(number)
    - number = check_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=int, arg_name=arg_name, prim_name=prim_name)
357
+
358
+
359
def check_non_positive_int(arg_value, arg_name=None, prim_name=None):
    """
    Check argument is non-positive integer, which mean arg_value <= 0.

    Delegates to `_check_number` with the LE relation against 0 and ``arg_type=int``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_non_positive_int(number)
    - number = check_non_positive_int(number, "bias")
    """
    return _check_number(arg_value, 0, LE, int, arg_name, prim_name)
368
+
369
+
370
def check_non_negative_int(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a non-negative integer (``arg_value >= 0``).

    Delegates to `_check_number` with the GE relation against 0 and ``arg_type=int``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_non_negative_int(number)
    - number = check_non_negative_int(number, "bias")
    """
    return _check_number(arg_value, 0, GE, arg_type=int, arg_name=arg_name, prim_name=prim_name)
379
+
380
+
381
def check_non_negative_int_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a non-negative integer (>= 0).

    Each element is validated with `_check_number`. Error messages name the offending
    element as ``<arg_name>[<index>]``, or ``arg_name[<index>]`` when `arg_name` is not given.

    Returns:
        The unchanged `sequence`.

    Usage:
    - sequence = check_non_negative_int_sequence(sequence)
    - sequence = check_non_negative_int_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GE, int, f"{base_name}[{idx}]", prim_name)
    return sequence
394
+
395
+
396
def check_float(arg_value, value, rel, arg_name=None, prim_name=None):
    """
    Check that float `arg_value` satisfies ``arg_value <rel> value``.

    Delegates to `_check_number` with ``arg_type=float``, so an int argument is rejected.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_float(number, 0.0, GE, "number", None)  # number >= 0
    """
    return _check_number(arg_value, value, rel, arg_type=float, arg_name=arg_name, prim_name=prim_name)
404
+
405
+
406
def check_is_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a float and return it.

    Delegates to `check_is_number` with ``arg_type=float``.

    Usage:
    - number = check_is_float(number)
    - number = check_is_float(number, "bias")
    - number = check_is_float(number, "bias", "bias_class")
    """
    return check_is_number(arg_value, arg_type=float, arg_name=arg_name, prim_name=prim_name)
416
+
417
+
418
def check_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a positive float (``arg_value > 0``).

    Delegates to `_check_number` with the GT relation against 0 and ``arg_type=float``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_positive_float(number)
    - number = check_positive_float(number, "bias")
    - number = check_positive_float(number, "bias", "bias_class")
    """
    return _check_number(arg_value, 0, GT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
428
+
429
+
430
def check_positive_float_sequence(sequence, arg_name=None, prim_name=None):
    """
    Check that every element of `sequence` is a positive float (> 0).

    Each element is validated with `_check_number`. Error messages name the offending
    element as ``<arg_name>[<index>]``, or ``arg_name[<index>]`` when `arg_name` is not given.

    Returns:
        The unchanged `sequence`.

    Usage:
    - sequence = check_positive_float_sequence(sequence)
    - sequence = check_positive_float_sequence(sequence, "dims")
    """
    base_name = arg_name if arg_name else 'arg_name'
    for idx, element in enumerate(sequence):
        _check_number(element, 0, GT, float, f"{base_name}[{idx}]", prim_name)
    return sequence
443
+
444
+
445
def check_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a negative float (``arg_value < 0``).

    Delegates to `_check_number` with the LT relation against 0 and ``arg_type=float``.

    Returns:
        The unchanged `arg_value`.

    Usage:
    - number = check_negative_float(number)
    - number = check_negative_float(number, "bias")
    """
    return _check_number(arg_value, 0, LT, arg_type=float, arg_name=arg_name, prim_name=prim_name)
454
+
455
+
456
def check_non_positive_float(arg_value, arg_name=None, prim_name=None):
    """
    Check argument is non-positive float, which means arg_value <= 0.

    Usage:
        - number = check_non_positive_float(number)
        - number = check_non_positive_float(number, "bias")

    Returns:
        Whatever `_check_number` returns for these arguments (the check itself is delegated).
    """
    return _check_number(arg_value, 0, LE, float, arg_name, prim_name)
+
466
+
467
def check_non_negative_float(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a non-negative float, i.e. arg_value >= 0.

    Usage:
        - number = check_non_negative_float(number)
        - number = check_non_negative_float(number, "bias")
    """
    lower_bound = 0
    return _check_number(arg_value, lower_bound, GE, float, arg_name, prim_name)
+
477
+
478
def check_number(arg_name, arg_value, value, rel, prim_name):
    """
    Check that `arg_value` satisfies relation `rel` against `value`.

    Returns:
        `arg_value` when `_check_binary_rel(arg_value, value, rel)` is truthy.

    Raises:
        ValueError: otherwise. The relation is described via `_format_str_one_value`.
    """
    if _check_binary_rel(arg_value, value, rel):
        return arg_value
    relation_text = _format_str_one_value(value, rel)
    raise ValueError(f"For '{prim_name}', the argument '{arg_name}' "
                     f"must {relation_text}, but got {arg_value}.")
+
490
+
491
def check_isinstance(arg_name, arg_value, classes):
    """
    Check that `arg_value` is an instance of `classes`.

    Returns:
        `arg_value` unchanged when the check passes.

    Raises:
        ValueError: if `arg_value` is not an instance of `classes`. Note that this is
            ValueError, not TypeError; callers may rely on that.
    """
    if isinstance(arg_value, classes):
        return arg_value
    raise ValueError(f"The parameter '{arg_name}' must be isinstance of {classes}, but got {arg_value}.")
+
501
+
502
def check_bool(arg_value, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is an instance of bool.

    Usage:
        - has_bias = check_bool(has_bias)
        - has_bias = check_bool(has_bias, "has_bias")

    Returns:
        `arg_value` unchanged when it is a bool.

    Raises:
        TypeError: if `arg_value` is not a bool.
    """
    if isinstance(arg_value, bool):
        return arg_value
    subject = f"For '{prim_name}', the" if prim_name else 'The'
    target = f"'{arg_name}'" if arg_name else 'input value'
    raise TypeError(f"{subject} {target} must be a bool, but got {type(arg_value).__name__}.")
+
520
+
521
def check_int_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Method for checking whether input value is in int range.

    `rel` selects which bounds are inclusive. `INC_NEITHER` excludes both bounds, so the
    example below accepts the open interval (0, 1).

    Usage:
        - number = check_int_range(number, 0, 1, INC_NEITHER)  # number in (0, 1)
        - number = check_int_range(number, 0, 1, INC_NEITHER, "number")  # number in (0, 1)

    Returns:
        Whatever `check_number_range` returns (the check is delegated with `int` as the type).
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, int, arg_name, prim_name)
+
531
+
532
def check_float_range(arg_value, lower_limit, upper_limit, rel, arg_name=None, prim_name=None):
    """
    Method for checking whether input value is in float range.

    `rel` selects which bounds are inclusive. `INC_NEITHER` excludes both bounds, so the
    example below accepts the open interval (0.0, 1.0).

    Usage:
        - number = check_float_range(number, 0.0, 1.0, INC_NEITHER)  # number in (0.0, 1.0)
        - number = check_float_range(number, 0.0, 1.0, INC_NEITHER, "number")  # number in (0.0, 1.0)

    Returns:
        Whatever `check_number_range` returns (the check is delegated with `float` as the type).
    """
    return check_number_range(arg_value, lower_limit, upper_limit, rel, float, arg_name, prim_name)
+
542
+
543
def check_string(arg_value, valid_values, arg_name=None, prim_name=None):
    """
    Check that `arg_value` is a str and is one of `valid_values`.

    Usage:
        - method = check_string(method, ["string1", "string2", "string3"], "method")

    Returns:
        `arg_value` unchanged when valid.

    Raises:
        ValueError: if `arg_value` is not a str or not in `valid_values`.
    """
    if isinstance(arg_value, str) and arg_value in valid_values:
        return arg_value
    name = arg_name if arg_name else "parameter"
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    raise ValueError(f"{prefix} '{name}' must be str and must be in '{valid_values}',"
                     f" but got '{arg_value}'.")
+
561
+
562
def check_str_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that the string `target` matches the regular expression `reg`.

    Args:
        target (str): string to validate.
        reg (str, optional): pattern passed to `re.match`. Defaults to a "named string" pattern
            that starts with a word character followed by alphanumerics, '_' or '.'.
        flag: flags passed to `re.match`. Default: `re.ASCII`.
        prim_name (str, optional): caller name used in the error message.

    Returns:
        True when `target` matches.

    Raises:
        ValueError: if `re.match(reg, target, flag)` returns None.
    """
    if reg is None:
        # Named string regular expression
        reg = r"^\w+[0-9a-zA-Z\_\.]*$"
    if re.match(reg, target, flag) is None:
        prim_name = f"For '{prim_name}', the" if prim_name else "The"
        raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular '{reg}' by flags '{flag}'.")
    return True
+
571
+
572
# pylint: disable=missing-docstring
def check_str_and_none_by_regular(target, reg=None, flag=re.ASCII, prim_name=None):
    """
    Check that the string `target` matches `reg`. With the default pattern an empty string is accepted.

    Args:
        target (str): string to validate.
        reg (str, optional): pattern passed to `re.match`. The default allows word characters,
            alphanumerics, '_', '.', '-', and also the empty string.
        flag: flags passed to `re.match`. Default: `re.ASCII`.
        prim_name (str, optional): caller name used in the error message.

    Returns:
        True when `target` matches.

    Raises:
        ValueError: if `re.match(reg, target, flag)` returns None.
    """
    if reg is None:
        # Named string regular expression
        reg = r"^\w*[0-9a-zA-Z\_\.\-]*$"
    if re.match(reg, target, flag) is None:
        prim_name = f"For '{prim_name}', the" if prim_name else "The"
        raise ValueError(f"{prim_name} '{target}' is illegal, it must be match regular '{reg}' by flags '{flag}'.")
    return True
+
582
+
583
def check_file_name_by_regular(target, reg=None, prim_name=None):
    """
    Check whether `target` is a legitimate file name.

    `target` must be a str, must not end with a path separator ('\\' or '/'), and must
    match `reg`. By default `reg` allows alphanumerics and '@', '_', '-', '.', ':', '/', '\\'.

    Returns:
        True when all checks pass.

    Raises:
        TypeError: if `target` is not a str.
        ValueError: if `target` looks like a directory path or does not match `reg`.
    """
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    if not isinstance(target, str):
        raise TypeError(f"{prefix} '{target}' must be string, but got {type(target)}.")
    if target.endswith(("\\", "/")):
        raise ValueError(f"{prefix} '{target}' cannot be a directory path.")
    pattern = r"^[0-9a-zA-Z@\_\-\.\:\/\\]+$" if reg is None else reg
    if re.match(pattern, target) is None:
        raise ValueError(f"{prefix} '{target}' is illegal, it must be match regular '{pattern}'.")
    return True
+
599
+
600
def check_pad_value_by_mode(pad_mode, padding, prim_name):
    """
    Validate `padding` against `pad_mode`: a non-zero padding is only allowed for 'pad'.

    Returns:
        `padding` unchanged when valid.

    Raises:
        ValueError: if `pad_mode` is not 'pad' and `padding` is not 0.
    """
    if pad_mode == 'pad' or padding == 0:
        return padding
    raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}',"
                     f" but got {padding}.")
+
607
+
608
def check_subclass(arg_name, type_, template_types, prim_name, addition_error_info=None):
    """
    Check whether `type_` matches one of `template_types`.

    An entry of `template_types` that is an `mstype.Type` matches when
    `mstype._issubclass_(type_, entry)` is true. Any other entry matches only when it is the
    very same object as `type_`.

    Args:
        arg_name (str): name of the checked argument, used in the error message.
        type_: the type to check.
        template_types: a single allowed type or an iterable of allowed types.
        prim_name (str): name of the calling primitive, used in the error message.
        addition_error_info (str, optional): extra text appended to the error message.

    Raises:
        TypeError: if `type_` matches none of `template_types`.
    """
    if not isinstance(template_types, Iterable):
        template_types = (template_types,)
    hit = False
    for template_type in template_types:
        if isinstance(template_type, mstype.Type):
            if mstype._issubclass_(type_, template_type):  # pylint: disable=W0212
                hit = True
                break
        elif type_ is template_type:
            hit = True
            break
    if hit:
        return
    addition_error_info = '' if addition_error_info is None else ' ' + addition_error_info
    type_str = (f"type '{type(type_).__name__}'" if isinstance(type_, (tuple, list)) else str(type_))
    raise TypeError(f"For '{prim_name}', the element of '{arg_name}'"
                    f" must be {'one of ' if len(template_types) > 1 else ''}"
                    f"{', '.join((str(x) for x in template_types))}, but got {type_str}"
                    f"{addition_error_info}. The supported data types depend on the hardware that"
                    f" executes the operator, for more details, please refer to the MindSpore official "
                    f"website to get more information about the data type.")
+
634
+
635
def check_valid_input(arg_name, arg_value, prim_name):
    """
    Check that `arg_value` is not None.

    Returns:
        `arg_value` unchanged when it is not None.

    Raises:
        ValueError: if `arg_value` is None.
    """
    if arg_value is None:
        raise ValueError(f"For '{prim_name}', the argument '{arg_name}' "
                         f"can not be None, but got {arg_value}.")
    return arg_value
+
646
+
647
def check_types_same_and_valid(args, valid_values, prim_name):
    """
    Check that every type in `args` is valid and that all of them are equal.

    Args:
        args (dict): mapping of argument name to its type.
        valid_values: allowed types, forwarded to `check_subclass`.
        prim_name (str): name of the calling primitive, used in error messages.

    Raises:
        TypeError: if a type fails `check_subclass`, or differs from the first argument's type.
            An empty `args` also raises TypeError, from `reduce` with no initial value.
    """

    def _validated(item):
        name, elem_type = item
        check_subclass(name, elem_type, valid_values, prim_name)
        return name, elem_type

    def _same_as_first(first, other):
        first_name, first_type = first
        other_name, other_type = other
        if first_type != other_type:
            raise TypeError(f"For '{prim_name}', the type of '{other_name}' should be same as '{first_name}',"
                            f" but got '{first_name}' with type {first_type}"
                            f" and '{other_name}' with type {other_type}.")
        return first

    # Lazy generator: each item is validated just before it is compared with the first one.
    reduce(_same_as_first, (_validated(item) for item in args.items()))
+
668
+
669
def check_tensors_dtypes_same_and_valid(args, valid_dtypes, prim_name):
    """
    Check that the element dtypes of the input tensors in `args` are identical and valid.

    Each entry of `valid_dtypes` (or a single non-iterable dtype) is wrapped in
    `mstype.TensorType` before being forwarded to `check_types_same_and_valid`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    tensor_types = list(map(mstype.TensorType, valid_dtypes))
    check_types_same_and_valid(args, tensor_types, prim_name)
+
675
+
676
def check_tensor_dtype_valid(arg_name, arg_type, valid_dtypes, prim_name):
    """
    Check that the tensor type `arg_type` has one of the element dtypes in `valid_dtypes`.

    Each entry of `valid_dtypes` (or a single non-iterable dtype) is wrapped in
    `mstype.TensorType` and the comparison is delegated to `check_subclass`.
    """
    if not isinstance(valid_dtypes, Iterable):
        valid_dtypes = [valid_dtypes]
    tensor_types = list(map(mstype.TensorType, valid_dtypes))
    check_subclass(arg_name, arg_type, tensor_types, prim_name)
+
682
+
683
def check_scalar_or_tensor_types_same(args, valid_values, prim_name, allow_mix=False):
    """
    Check that the types in `args` are valid and identical.

    Tensor types are compared by their element type. If `allow_mix` is True, a tensor type and
    a scalar type are compatible when the tensor's element type equals the scalar type
    (e.g. Tensor(float32) vs float32). Otherwise mixing a tensor with a scalar raises.

    Args:
        args (dict): mapping of argument name to its (tensor or scalar) type.
        valid_values: container of allowed scalar/element types.
        prim_name (str): name of the calling primitive, used in error messages.
        allow_mix (bool): whether tensor and scalar types may be mixed. Default: False.

    Raises:
        TypeError: if a type is not in `valid_values`, or two types are incompatible.
            An empty `args` also raises TypeError, from `reduce` with no initial value.
    """
    tensor_cls = type(mstype.tensor_type)

    def _unwrap(dtype):
        return dtype.element_type() if isinstance(dtype, tensor_cls) else dtype

    def _validate(item):
        name, dtype = item
        scalar_dtype = _unwrap(dtype)
        if scalar_dtype not in valid_values:
            raise TypeError(f"For '{prim_name}', the type of '{name}' must be in {valid_values},"
                            f" but got {scalar_dtype}.")
        return item

    def _compare(lhs, rhs):
        lhs_name, lhs_type = lhs
        rhs_name, rhs_type = rhs
        lhs_is_tensor = isinstance(lhs_type, tensor_cls)
        rhs_is_tensor = isinstance(rhs_type, tensor_cls)
        mismatch = False
        if lhs_is_tensor and rhs_is_tensor:
            lhs_type, rhs_type = lhs_type.element_type(), rhs_type.element_type()
        elif lhs_is_tensor or rhs_is_tensor:
            # Exactly one side is a tensor type.
            if allow_mix:
                lhs_type, rhs_type = _unwrap(lhs_type), _unwrap(rhs_type)
            else:
                mismatch = True
        if mismatch or lhs_type != rhs_type:
            raise TypeError(f"For '{prim_name}', the type of '{rhs_name}' must be same as '{lhs_name}',"
                            f" but got '{lhs_name}' with type {lhs_type}"
                            f" and '{rhs_name}' with type {rhs_type}.")
        return lhs

    # Lazy generator: each item is validated just before it is compared with the first one.
    reduce(_compare, (_validate(item) for item in args.items()))
+
722
+
723
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check that `arg_value` is an instance of one of `valid_types`.

    bool is a subclass of int, so a bool value is rejected unless `bool` itself is listed:
    `check_value_type('x', True, [int])` fails and `check_value_type('x', True, [bool, int])` passes.
    A float value is rounded to 6 decimals when `float` is not listed, and the rounded value
    is what gets returned.

    Returns:
        `arg_value` (possibly rounded as described above) when the check passes.

    Raises:
        TypeError: if `arg_value` is not an instance of `valid_types`.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)

    def _fail(value):
        names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
        count = len(valid_types)
        prefix = f"For '{prim_name}', the" if prim_name else "The"
        raise TypeError(f"{prefix} type of '{arg_name}' should be {'one of ' if count > 1 else ''}"
                        f"'{names if count > 1 else names[0]}', "
                        f"but got type '{type(value).__name__}'.")

    if isinstance(arg_value, bool) and bool not in tuple(valid_types):
        _fail(arg_value)
    if isinstance(arg_value, float) and float not in tuple(valid_types):
        arg_value = round(arg_value, 6)
    if not isinstance(arg_value, tuple(valid_types)):
        _fail(arg_value)
    return arg_value
+
748
+
749
def check_type_name(arg_name, arg_type, valid_types, prim_name):
    """
    Check that `arg_type` is one of `valid_types`. A tensor type is compared by its element type.

    Returns:
        `arg_type`, replaced by its element type when it is a tensor type.

    Raises:
        TypeError: if the (element) type is not in `valid_types`.
    """
    if not isinstance(valid_types, Iterable):
        valid_types = (valid_types,)
    if isinstance(arg_type, type(mstype.tensor_type)):
        arg_type = arg_type.element_type()
    if arg_type in valid_types:
        return arg_type
    names = [t.__name__ if hasattr(t, '__name__') else t for t in valid_types]
    count = len(valid_types)
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    got = arg_type.__name__ if hasattr(arg_type, '__name__') else repr(arg_type)
    raise TypeError(f"{prefix} '{arg_name}' should be {'one of ' if count > 1 else ''}"
                    f"{names if count > 1 else names[0]}, "
                    f"but got '{got}'.")
+
770
+
771
def check_reduce_shape(ori_shape, shape, axis, prim_name, arg_name1, arg_name2):
    """
    Check that `shape` equals `ori_shape` with the dimensions listed in `axis` removed.

    `axis` may be a single index or an iterable of indices. Indices are matched literally
    against positions 0..len(ori_shape)-1, so negative indices are not normalized.

    Raises:
        ValueError: if `shape` is not the reduced form of `ori_shape`.
    """
    axes = axis if isinstance(axis, Iterable) else (axis,)
    expected = [dim for position, dim in enumerate(ori_shape) if position not in axes]
    if list(shape) == expected:
        return
    raise ValueError(f"For '{prim_name}', "
                     f"the shape of parameter '{arg_name1}' reduce on 'axis': {axis} must "
                     f"be equal to the shape of '{arg_name2}': {shape}, but got {ori_shape}.")
+
781
+
782
def check_astype_dtype(dtype):
    """
    Validate a `Tensor.astype` dtype argument and convert it to an mstype.

    Accepted forms:
        - str: case-insensitive name from `mstype.__dtype__` or "int"/"float"/"bool",
          converted through `np.dtype`.
        - Python type: converted with `mstype.pytype_to_dtype`.
        - mstype: must be a number type or `mstype.bool_`, and is returned as is.

    Raises:
        TypeError: if `dtype` is not an accepted value.
    """
    all_types = mstype.__dtype__ + ["int", "float", "bool"]
    if isinstance(dtype, str):
        lowered = dtype.lower()
        if lowered not in all_types:
            raise TypeError(f"For Tensor.astype, the input type must be one of {all_types}, but got '{dtype}'.")
        return mstype.pytype_to_dtype(np.dtype(lowered))
    if isinstance(dtype, type):
        return mstype.pytype_to_dtype(dtype)
    supported = mstype.number_type + (mstype.bool_,)
    if dtype not in supported:
        raise TypeError(f"For Tensor.astype, the input type must be one of {supported},"
                        f" but got '{dtype}'.")
    return dtype
+
796
+
797
def check_transpose_axis(axes, ndim):
    """
    Normalize the `axes` argument of `Tensor.transpose`.

    Args:
        axes (tuple): the positional arguments given to transpose.
        ndim (int): rank of the tensor.

    Returns:
        - the reversed dimension order when `axes` is empty or is `(None,)`;
        - the single tuple/list argument as a tuple (its length is not checked here);
        - `(n,)` for a single int, which requires `ndim == 1`;
        - `axes` itself when several ints were given, which requires `len(axes) == ndim`.

    Raises:
        ValueError: if a single int or several ints do not match `ndim`.
        TypeError: if a single argument is not an int, tuple or list.
    """

    def _require_ndim_entries():
        # Several positional ints (or a single int) must supply exactly `ndim` entries.
        if len(axes) != ndim:
            raise ValueError(f"For Tensor.transpose, the number of axes must be equal to the dimension of Tensor, "
                             f"but got {len(axes)} in the number of axes.")

    if not axes or (len(axes) == 1 and axes[0] is None):
        return tuple(reversed(range(ndim)))

    if len(axes) != 1:
        _require_ndim_entries()
        return axes

    perm = axes[0]
    if isinstance(perm, int):
        _require_ndim_entries()
        return (perm,)
    if isinstance(perm, list):
        return tuple(perm)
    if isinstance(perm, tuple):
        return perm
    raise TypeError(f"For Tensor.transpose, the parameter 'axes' must be a tuple/list, "
                    f"or series of integer, but got {type(axes[0])}")
+
826
+
827
def check_reshape_shp(shp):
    """
    Normalize the `shape` argument of `Tensor.reshape`.

    Args:
        shp (tuple): the positional arguments given to reshape.

    Returns:
        `shp` unchanged when it holds several entries or a single int. A single list is
        returned as a tuple, and a single tuple is returned as is.

    Raises:
        TypeError: if a single argument is not an int, tuple or list.
    """
    if len(shp) != 1:
        return shp
    candidate = shp[0]
    if isinstance(candidate, int):
        return shp
    if isinstance(candidate, list):
        return tuple(candidate)
    if isinstance(candidate, tuple):
        return candidate
    raise TypeError(
        f"For Tensor.reshape, the parameter 'shape' must be an integer, or tuple/list, "
        f"or series of integer, but got {type(shp[0])}")
+
846
+
847
def check_flatten_order(order):
    """
    Validate the `order` argument of `Tensor.flatten`: it must be the string 'C' or 'F'.

    Raises:
        TypeError: if `order` is not a str.
        ValueError: if `order` is a str other than 'C' or 'F'.
    """
    if not isinstance(order, str):
        raise TypeError(f"For Tensor.flatten, the parameter 'order' must be a string, but got {type(order)}")
    if order in ('C', 'F'):
        return
    raise ValueError(f"For Tensor.flatten, the parameter 'order' must be 'C' or 'F', but got '{order}'")
+
854
+
855
def check_swapaxes_axis(axes, ndim):
    """
    Validate the axes argument of `ops.swapaxes` and map it to non-negative indices.

    Returns:
        For an int, the result of `check_axis_in_range`. For a tuple/list, a tuple of each
        axis mapped into [0, ndim).

    Raises:
        TypeError: if `axes` (or one of its elements) has an unsupported type.
        Whatever `check_axis_in_range` raises for an out-of-range axis.
    """
    if isinstance(axes, int):
        return check_axis_in_range(axes, ndim)
    if not isinstance(axes, (tuple, list)):
        raise TypeError(f"For ops.swapaxes, the argument 'axes' must be integer, list or tuple for check, "
                        f"but got {type(axes)}.")
    for axis in axes:
        if not isinstance(axis, int):
            raise TypeError(f"For ops.swapaxes, the axis argument must be integer, but got {type(axis)}.")
        check_axis_in_range(axis, ndim)
    return tuple((axis + ndim) % ndim for axis in axes)
+
871
+
872
def prepare_shape_for_squeeze(shape, axes):
    """
    Build the squeezed shape for `Tensor.squeeze` from `shape` and the selected `axes`.

    Args:
        shape (tuple): the shape of the tensor.
        axes (Union[int, tuple(int), list(int)]): the axes to squeeze. Negative values count
            from the end, and duplicate entries are ignored.

    Returns:
        new_shape (tuple): `shape` with every selected size-1 dimension removed.

    Raises:
        ValueError: if an axis is outside [-ndim, ndim), or a selected dimension is not 1.
        TypeError: if `axes` is not an int, tuple or list.
    """
    ndim = len(shape)

    def _validate(axis):
        if axis >= ndim or axis < -ndim:
            raise ValueError(f"For Tensor.squeeze, the 'axis' must be in the range of [-{ndim}, {ndim}), "
                             f"but got {axis}.")

    if isinstance(axes, int):
        _validate(axes)
        axes = (axes,)
    elif isinstance(axes, (list, tuple)):
        for axis in axes:
            _validate(axis)
        # Drop duplicates by equality while keeping first-seen order.
        unique = []
        for axis in axes:
            if axis not in unique:
                unique.append(axis)
        axes = tuple(unique)
    else:
        raise TypeError(f"For Tensor.squeeze, the parameter 'axes' must be one of [int, tuple, list], "
                        f"but got {type(axes)}")

    squeezed = ()
    for idx, dim in enumerate(shape):
        selected = idx in axes or idx - ndim in axes
        if selected and dim != 1:
            raise ValueError(f"For Tensor.squeeze, the shape of parameter 'axis' {axes} must be 1, but got {dim}.")
        if not selected or dim != 1:
            squeezed = squeezed + (dim,)
    return squeezed
+
923
+
924
def check_axis_in_range(axis, ndim):
    """
    Check that `axis` is an int in [-ndim, ndim) and return it mapped into [0, ndim).

    Raises:
        TypeError: if `axis` is not an int.
        ValueError: if `axis` is out of range.
    """
    if not isinstance(axis, int):
        raise TypeError(f'The axes must be integers, but got {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {axis}.")
    return (axis + ndim) % ndim
+
937
+
938
def check_axis_valid(axes, ndim):
    """
    Check that `axes` is valid for a tensor of rank `ndim` and normalize it.

    Returns:
        A tuple of non-negative axes. `None` means all axes `(0, ..., ndim-1)`, a single int
        is returned as a one-element tuple, and a tuple/list is mapped element-wise and then
        checked for duplicates via `_check_dup`.
    """
    if axes is None:
        return tuple(range(ndim))
    if not isinstance(axes, (tuple, list)):
        check_axis_in_range(axes, ndim)
        return (axes % ndim,)
    for axis in axes:
        check_axis_in_range(axis, ndim)
    normalized = tuple((axis + ndim) % ndim for axis in axes)
    _check_dup(normalized)
    return normalized
+
961
+
962
def max_(*args):
    """Return the largest positional argument, or the largest item when a single iterable is passed."""
    largest = max(*args)
    return largest
+
966
+
967
def min_(*args):
    """Return the smallest positional argument, or the smallest item when a single iterable is passed."""
    smallest = min(*args)
    return smallest
+
971
+
972
def is_stub_tensor(tensor):
    """Return True if `tensor` exposes a `stub` attribute, otherwise False."""
    try:
        tensor.stub
    except AttributeError:
        return False
    return True
+
975
+
976
def expanded_shape(ndim, axis_size, axis):
    """
    Return an `ndim`-long shape that is 1 everywhere except `axis_size` at position `axis`.

    `axis` is compared for equality with 0..ndim-1, so a negative `axis` selects nothing.
    """
    dims = [1] * ndim
    for position in range(ndim):
        if position == axis:
            dims[position] = axis_size
    return tuple(dims)
+
983
+
984
def tuple_slice(tup, start, end):
    """Return `tup[start:end]`."""
    window = slice(start, end)
    return tup[window]
+
988
+
989
def infer_out_shape(*shapes):
    """
    Return the output shape obtained by broadcasting `shapes` together.

    Shapes are right-aligned, and missing leading dimensions count as 1. On each axis every
    size must be 1 or the largest size on that axis. If any size on an axis is 0, the output
    size there is 0. Calling with no shapes returns `()`.

    Raises:
        ValueError: if the shapes cannot be broadcast.
    """
    shape_out = ()
    max_len = max((len(it) for it in shapes), default=0)
    for i in range(max_len):
        # Right-align each shape; positions before its first dimension are treated as 1.
        items = [it[i - (max_len - len(it))] if i - (max_len - len(it)) >= 0 else 1 for it in shapes]
        max_size = 0 if 0 in items else max(items)
        for item in items:
            if item not in (1, max_size):
                raise ValueError(f'For Tensor, the dimension on each axis must be 1 or the max value on the axis '
                                 f'to support broadcasting, but got shapes {shapes}')
        shape_out = shape_out + (max_size,)
    return shape_out
+
1009
+
1010
def check_axis_type(axis, type_int=True, type_tuple=True, type_list=True):
    """
    Check the type of an `axis` argument (used by Tensor.ptp).

    Args:
        axis: value to check.
        type_int (bool): accept a plain int. Default: True.
        type_tuple (bool): accept a tuple of ints. Default: True.
        type_list (bool): accept a list of ints. Default: True.

    Returns:
        True when `axis` has an accepted type.

    Raises:
        TypeError: if a tuple/list element is not an int, or `axis` has no accepted type.
    """
    if type_int and isinstance(axis, int):
        return True
    if (type_tuple and isinstance(axis, tuple)) or (type_list and isinstance(axis, list)):
        for ax in axis:
            if not isinstance(ax, int):
                raise TypeError(f"For Tensor.ptp, each axis must be integer, but got {type(ax)} in {axis}.")
        return True

    # Join only the enabled type names so the message has no dangling separator.
    allowed = [name for name, enabled in (("int", type_int), ("tuple", type_tuple), ("list", type_list))
               if enabled]
    raise TypeError(f"For Tensor.ptp, the axis should be {', '.join(allowed)}, but got {type(axis)}.")
+
1029
+
1030
def check_and_canonicalize_axes(axes, ndim):
    """
    Validate `axes` against `ndim` and convert them to a tuple of non-negative indices.

    A non-tuple `axes` is treated as a single axis. Duplicates are checked via `_check_dup`.

    Raises:
        TypeError: if an axis is not an int.
        ValueError: if an axis is outside [-ndim, ndim).
    """
    if not isinstance(axes, tuple):
        axes = (axes,)
    canonical = []
    for ax in axes:
        if not isinstance(ax, int):
            raise TypeError(f"Each axis should be integer, but got {type(ax)} in {axes}.")
        if ax >= ndim or ax < -ndim:
            raise ValueError(f"The 'axis' must be in the range of [-{ndim}, {ndim}), but got {ax}.")
        canonical.append(ax + ndim if ax < 0 else ax)
    canonical = tuple(canonical)
    _check_dup(canonical)
    return canonical
+
1048
+
1049
def check_type_support(dtype, device, supported_dtypes):
    """
    Return True if `dtype` is in `supported_dtypes`, or if the current device target is not `device`.

    The device target is only read when `dtype` is unsupported.
    """
    if dtype in supported_dtypes:
        return True
    return not context.get_context('device_target') == device
+
1053
+
1054
def check_sparse_tensor_input(indices, values, shape):
    """
    Common input type check for sparse tensors.

    `indices` and `values` must be `Tensor_` instances or stub tensors, and `shape` must be a tuple.

    Raises:
        TypeError: on the first argument with the wrong type.
    """
    for name, candidate in (("indices", indices), ("values", values)):
        if not (isinstance(candidate, Tensor_) or is_stub_tensor(candidate)):
            raise TypeError(f"For SparseTensors, '{name}' must be Tensor, but got {type(candidate)}.")
    if not isinstance(shape, tuple):
        raise TypeError(f"For SparseTensors, 'shape' must be tuple, but got {type(shape)}.")
+
1063
+
1064
def check_csr_tensor_input(indptr, indices, values, shape):
    """
    Check input types for CSRTensor.

    `indptr` must be a `Tensor_` or a stub tensor. The remaining arguments are checked by
    `check_sparse_tensor_input`.

    Raises:
        TypeError: if any argument has the wrong type.
    """
    indptr_ok = isinstance(indptr, Tensor_) or is_stub_tensor(indptr)
    if not indptr_ok:
        raise TypeError(f"For CSRTensor, 'indptr' must be Tensor, but got {type(indptr)}.")
    check_sparse_tensor_input(indices, values, shape)
+
1070
+
1071
def check_csr_tensor_shape(indptr_shp, indices_shp, values_shp, csr_shp):
    """
    Check that the component shapes of a CSRTensor are mutually consistent.

    Args:
        indptr_shp (tuple): shape of `indptr`. Must be 1-D with length `csr_shp[0] + 1`.
        indices_shp (tuple): shape of `indices`. Must be 1-D with length `values_shp[0]`.
        values_shp (tuple): shape of `values`. Its rank must be `len(csr_shp) - 1` and
            `values_shp[1:]` must equal `csr_shp[2:]`.
        csr_shp (tuple): dense shape. Every element must be a positive int.

    Returns:
        None. Returns early for an empty tensor, i.e. when all three component shapes are `(0,)`.

    Raises:
        TypeError: if an element of `csr_shp` is not an int.
        ValueError: if any of the shape constraints above is violated.
    """
    # Support empty sparse tensor
    if (indptr_shp == (0,)) and (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shape_size = 1
    val_shp_size = 1
    for item in csr_shp:
        # Check the type first so a non-int element gets a clear TypeError instead of failing on `<=`.
        if not isinstance(item, int):
            raise TypeError(f"For CSRTensor, the element type of shape must be int, but got {type(item)}")
        if item <= 0:
            raise ValueError(f"For CSRTensor, the element of shape must be positive, but got {item}")
        shape_size *= item
    for item in values_shp:
        if item <= 0:
            raise ValueError(f"The element of shape must be positive, but got {item}")
        val_shp_size *= item
    if shape_size < val_shp_size:
        raise ValueError(f"Shape total size: {shape_size} is too small to hold {val_shp_size} non-zero values.")
    if len(indices_shp) != 1:
        raise ValueError(f"For CSRTensor, indices must be a 1-dimensional tensor, "
                         f"but got a {len(indices_shp)} dimension tensor.")
    if len(indptr_shp) != 1:
        raise ValueError(f"For CSRTensor, indptr must be a 1-dimensional tensor, "
                         f"but got a {len(indptr_shp)} dimension tensor.")
    if csr_shp[0] + 1 != indptr_shp[0]:
        raise ValueError(f"For CSRTensor, indptr must have length (1 + shape[0]), "
                         f"but got: {indptr_shp[0]}")
    if indices_shp[0] != values_shp[0]:
        err_msg1 = "For CSRTensor, indices and values must equal in their shape, "
        err_msg2 = f"but got indices shape: {indices_shp[0]}, values shape: {values_shp[0]}."
        raise ValueError(err_msg1 + err_msg2)
    if len(values_shp) + 1 != len(csr_shp):
        raise ValueError(f"Values' dimension should equal to CSRTensor's dimension - 1, but got "
                         f"Values' dimension: {len(values_shp)}, CSRTensor's dimension: "
                         f"{len(csr_shp)}")
    if values_shp[1:] != csr_shp[2:]:
        raise ValueError(f"CSRTensor's shape[2: ] must be equal to value's shape[1: ], "
                         f"but CSRTensor's shape[2: ] got: {csr_shp[2:]} and value's shape[1: ] "
                         f"got: {values_shp[1:]}")
+
1112
+
1113
def check_csr_tensor_dtype(indptr_dtype, indices_dtype):
    """
    Check that the `indptr` and `indices` dtypes of a CSRTensor are int16, int32 or int64.

    Raises:
        TypeError: for the first dtype (indptr checked before indices) outside that set.
    """
    allowed = (mstype.int16, mstype.int32, mstype.int64)
    for name, dtype in (("indptr", indptr_dtype), ("indices", indices_dtype)):
        if dtype not in allowed:
            raise TypeError(f"For CSRTensor, {name} must have int16 or int32 or int64 data type, "
                            f"but got {dtype}.")
+
1122
+
1123
def check_coo_tensor_input(indices, values, shape):
    """Check input types for COOTensor by delegating to `check_sparse_tensor_input`."""
    check_sparse_tensor_input(indices=indices, values=values, shape=shape)
+
1127
+
1128
def check_coo_tensor_shape(indices_shp, values_shp, coo_shp):
    """
    Check that the component shapes of a COOTensor are mutually consistent.

    Args:
        indices_shp (tuple): shape of `indices`. Must be `(N, 2)`.
        values_shp (tuple): shape of `values`. Must be `(N,)`.
        coo_shp (tuple): dense shape. Must have length 2 with positive int elements, and
            its product must be at least N.

    Returns:
        None. Returns early for an empty tensor (`indices_shp` and `values_shp` both `(0,)`),
        after the length-2 check on `coo_shp`.

    Raises:
        TypeError: if an element of `coo_shp` is not an int.
        ValueError: if any of the shape constraints above is violated.
    """
    if len(coo_shp) != 2:
        raise ValueError(f"For COOTensor, the length of 'shape' must be 2, but got {coo_shp}.")
    if (indices_shp == (0,)) and (values_shp == (0,)):
        return
    shp_mul = 1
    for sh in coo_shp:
        # Check the type first so a non-int element gets a clear TypeError instead of failing on `<=`.
        if not isinstance(sh, int):
            raise TypeError(f"For COOTensor, the element type of 'shape' must be int, but got {type(sh)}")
        if sh <= 0:
            raise ValueError(f"For COOTensor, the element of 'shape' must be positive, but got {sh} in {coo_shp}.")
        shp_mul *= sh
    if shp_mul < values_shp[0]:
        raise ValueError(f"For COOTensor, shape is too small: ({shp_mul}) to hold all values({values_shp[0]}).")
    if len(indices_shp) != 2:
        raise ValueError(f"For COOTensor, 'indices' must be a 2-dimensional tensor, but got a {len(indices_shp)}"
                         f"-dimensional tensor.")
    if len(values_shp) != 1:
        raise ValueError(f"For COOTensor, 'values' must be a 1-dimensional tensor, but got a {len(values_shp)}"
                         f"-dimensional tensor.")
    if indices_shp[0] != values_shp[0]:
        raise ValueError(f"For COOTensor, 'indices.shape[0]' must be equal to 'values.shape[0]', but got "
                         f"'indices.shape[0]' = {indices_shp[0]} and 'values.shape[0]' = {values_shp[0]}.")
    if indices_shp[1] != 2:
        raise ValueError(f"For COOTensor, 'indices.shape[1]' must be 2, but got {indices_shp[1]}.")
+
1155
+
1156
def check_coo_tensor_dtype(indices_dtype):
    """
    Check that the `indices` dtype of a COOTensor is int16, int32 or int64.

    Raises:
        TypeError: if it is not one of those dtypes.
    """
    allowed = (mstype.int16, mstype.int32, mstype.int64)
    if indices_dtype in allowed:
        return
    raise TypeError(f"For COOTensor, the type of 'indices' must be one of [int16, int32, int64], but got "
                    f"{indices_dtype}.")
+
1162
+
1163
def check_element_type_of_iterable(arg_name, arg_value, valid_types, prim_name=None):
    """
    Check that `arg_value` is a list or tuple and that every element is an instance of `valid_types`.

    Raises:
        TypeError: from `check_value_type` if `arg_value` is not a list/tuple, or for the first
            element whose type is not in `valid_types`.
    """
    check_value_type(arg_name, arg_value, [list, tuple], prim_name)
    names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in valid_types]
    count = len(valid_types)
    prefix = f"For '{prim_name}', the" if prim_name else "The"
    allowed = tuple(valid_types)
    for element in arg_value:
        if isinstance(element, allowed):
            continue
        raise TypeError(f"{prefix} type of '{arg_name}' should be {'one of ' if count > 1 else ''}"
                        f"{names if count > 1 else names[0]}, "
                        f"but got '{element}' with type '{type(element).__name__}'.")
+
1175
+
1176
def check_element_type_of_dict(arg_name, arg_value, key_types, value_types, prim_name=None):
    """
    Check that `arg_value` is a dict whose keys are instances of `key_types` and whose values
    are instances of `value_types`. All keys are checked before any value.

    Raises:
        TypeError: from `check_value_type` if `arg_value` is not a dict, or for the first key or
            value with a disallowed type.
    """
    check_value_type(arg_name, arg_value, [dict], prim_name)
    prefix = f"For '{prim_name}', the" if prim_name else "The"

    def _check_all(elements, allowed_types):
        names = [t.__name__ if hasattr(t, '__name__') else str(t) for t in allowed_types]
        count = len(allowed_types)
        for element in elements:
            if not isinstance(element, tuple(allowed_types)):
                raise TypeError(f"{prefix} type of '{arg_name}' should be {'one of ' if count > 1 else ''}"
                                f"{names if count > 1 else names[0]}, "
                                f"but got '{element}' with type '{type(element).__name__}'.")

    _check_all(arg_value.keys(), key_types)
    _check_all(arg_value.values(), value_types)
+
1196
+
1197
def check_size_and_element_type_of_tuple(arg_name, arg_value, expect_size, expect_element_type, prim_name=None):
    """
    Check that `arg_value` is a tuple of exactly `expect_size` elements, each an instance of
    `expect_element_type`.

    The three checks are delegated to `check_value_type`, `check_equal_int` and
    `check_element_type_of_iterable`, in that order.
    """
    check_value_type(arg_name, arg_value, [tuple], prim_name)
    check_equal_int(len(arg_value), expect_size, arg_name + ' size', prim_name)
    # Pass the caller's argument name (not the literal 'arg_name') so errors name the real parameter.
    check_element_type_of_iterable(arg_name, arg_value, [expect_element_type], prim_name)
+
1203
+
1204
def _check_symbol(dyn_input, net_input, index, symbolic_shape_data):
    """
    Validate the actual shape of one network input against its symbolic (dynamic) shape.

    Each entry of ``dyn_input.symbolic_shape`` that is a dict (a converted `Symbol`) constrains
    the matching axis of ``net_input.shape``. Non-dict entries are skipped. Supported keys:

    - ``"id"``: every axis with the same id must have the same value. The values seen so far
      are stored in ``symbolic_shape_data["unique_id_value_map"]`` (created on demand), so this
      constraint spans every call that shares the same ``symbolic_shape_data`` dict.
    - ``"min"`` / ``"max"``: inclusive bounds on the axis value.
    - ``"divisor"`` (default 1) / ``"remainder"`` (default 0): the value must be at least
      ``divisor`` and satisfy ``value % divisor == remainder``.

    Args:
        dyn_input: object exposing ``symbolic_shape``.
        net_input: object exposing ``shape``.
        index (int): zero-based position of the input, used only in error messages.
        symbolic_shape_data (dict): mutable state shared across calls (see ``"id"`` above).

    Raises:
        ValueError: if any constraint is violated.
    """
    actual_shape = net_input.shape
    for axis, sym in enumerate(dyn_input.symbolic_shape):
        # Only Symbol objects (converted to dicts) carry constraints.
        if not isinstance(sym, dict):
            continue
        if "id" in sym:
            id_to_value = symbolic_shape_data.setdefault("unique_id_value_map", {})
            previous = id_to_value.setdefault(sym["id"], actual_shape[axis])
            if previous != actual_shape[axis]:
                raise ValueError(
                    f"The {axis + 1}th shape value of {index + 1}th actual input args is a unique symbol, all values must "
                    f"be the same. The previous value is {previous}, but the current value is "
                    f"{actual_shape[axis]}. Actual shape: {actual_shape}, axis: {axis}.")
        dim = actual_shape[axis]
        if "min" in sym and dim < sym["min"]:
            raise ValueError(
                f"The {axis + 1}th shape value of {index + 1}th actual input args must be greater than or equal to the "
                f"'min' value '{sym['min']}' of `Symbol`, but got '{dim}'. Actual shape: {actual_shape}, "
                f"axis: {axis}.")
        if "max" in sym and dim > sym["max"]:
            raise ValueError(
                f"The {axis + 1}th shape value of {index + 1}th actual input args must be less than or equal to the "
                f"'max' value '{sym['max']}' of `Symbol`, but got '{dim}'. Actual shape: {actual_shape}, "
                f"axis: {axis}.")
        divisor = sym.get("divisor", 1)
        remainder = sym.get("remainder", 0)
        # The value must have the form divisor * N + remainder with N >= 1.
        if dim < divisor or dim % divisor != remainder:
            raise ValueError(
                f"The {axis + 1}th shape value of {index + 1}th actual input args must be match the 'divisor'(d) and "
                f"'remainder'(r) of `Symbol`. The value should be 'd * N + r' for 'N > 0', got d={divisor} and "
                f"r={remainder}, but actual shape value is '{dim}'. Actual shape: {actual_shape}, axis: {axis}")
+
1245
+
1246
+ def check_symbolic_shape(dynamic_inputs, actual_inputs):
1247
+ """Check the symboic shape"""
1248
+ symbolic_shape_data = {}
1249
+
1250
+ def run_check(dyn_inputs, net_inputs):
1251
+ """the real checking function"""
1252
+ for index, (dyn_input, net_input) in enumerate(zip(dyn_inputs, net_inputs)):
1253
+ if isinstance(dyn_input, (tuple, list)):
1254
+ run_check(dyn_input, net_input)
1255
+ elif hasattr(dyn_input, "symbolic_shape"):
1256
+ _check_symbol(dyn_input, net_input, index, symbolic_shape_data)
1257
+
1258
+ run_check(dynamic_inputs, actual_inputs)
1259
+
1260
+
1261
+ def check_input_format(input_param):
1262
+ """Judge input format."""
1263
+ if input_param == "NCHW":
1264
+ return input_param
1265
+ raise ValueError(f"The data format must be NCHW, but got {input_param}.")
1266
+
1267
+
1268
+ def _expand_tuple(n_dimensions):
1269
+ """To expand an int number to tuple."""
1270
+
1271
+ def convert(m):
1272
+ if not isinstance(m, tuple):
1273
+ if isinstance(m, int) and not isinstance(m, bool):
1274
+ return tuple(repeat(m, n_dimensions))
1275
+ raise TypeError(f"When expanding an int number to tuple, input type must be integer or tuple[int], " \
1276
+ f"but got {type(m)}")
1277
+
1278
+ if not len(m) is n_dimensions:
1279
+ raise TypeError(f"When expanding an int number to tuple, input tuple dimension must be {n_dimensions}, " \
1280
+ f"but got {m}")
1281
+
1282
+ for i in m:
1283
+ if not isinstance(i, int) or isinstance(i, bool):
1284
+ raise TypeError(f"When expanding an int number to tuple, " \
1285
+ f"the type of element in input tuple must be an integer, but got {type(i)}.")
1286
+ return m
1287
+
1288
+ return convert
1289
+
1290
+
1291
+ def _check_data_type_valid(data, valid_type):
1292
+ """Check data type valid."""
1293
+ if valid_type is None:
1294
+ return data is None
1295
+ if isinstance(data, valid_type):
1296
+ if hasattr(data, 'size') and data.size == 0:
1297
+ msg = "The input data can not be empty."
1298
+ logger.critical(msg)
1299
+ raise ValueError(msg)
1300
+ return True
1301
+ return False
1302
+
1303
+
1304
+ def check_input_data(*data, data_class):
1305
+ """Input data check."""
1306
+ for item in data:
1307
+ if isinstance(item, (list, tuple)):
1308
+ for v in item:
1309
+ check_input_data(v, data_class=data_class)
1310
+ elif isinstance(item, dict):
1311
+ for v in item.values():
1312
+ check_input_data(v, data_class=data_class)
1313
+ else:
1314
+ if isinstance(data_class, (tuple, list)):
1315
+ ret = True in tuple(_check_data_type_valid(item, data_type) for data_type in data_class)
1316
+ else:
1317
+ ret = _check_data_type_valid(item, data_class)
1318
+ if not ret:
1319
+ data_class_str = tuple(i.__name__ if hasattr(i, '__name__') else i for i in data_class) if isinstance(
1320
+ data_class, (tuple, list)) else (data_class if data_class is None else data_class.__name__)
1321
+ raise TypeError(f'The types of input data must be in the Union({data_class_str}, ' \
1322
+ f'tuple[{data_class_str}], list[{data_class_str}], dict[{data_class_str}]), ' \
1323
+ f'but got type {item if item is None else type(item).__name__}.')
1324
+
1325
+
1326
+ def check_input_dataset(*dataset, dataset_type):
1327
+ """Input dataset check."""
1328
+ if not dataset:
1329
+ return False
1330
+ for item in dataset:
1331
+ if not isinstance(item, dataset_type):
1332
+ return False
1333
+ return True
1334
+
1335
+
1336
+ def check_output_data(data):
1337
+ """Output data check."""
1338
+ if data is None:
1339
+ raise RuntimeError('The output data can not be None, please check your net or input data.')
1340
+
1341
+
1342
+ once = _expand_tuple(1)
1343
+ twice = _expand_tuple(2)
1344
+ triple = _expand_tuple(3)
1345
+
1346
+
1347
+ def args_type_check(*type_args, **type_kwargs):
1348
+ """Check whether input data type is correct."""
1349
+
1350
+ def type_check(func):
1351
+ sig = inspect.signature(func)
1352
+ bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
1353
+
1354
+ @wraps(func)
1355
+ def wrapper(*args, **kwargs):
1356
+ nonlocal bound_types
1357
+ bound_values = sig.bind(*args, **kwargs)
1358
+ argument_dict = bound_values.arguments
1359
+ if "kwargs" in bound_types:
1360
+ bound_types = bound_types["kwargs"]
1361
+ if "kwargs" in argument_dict:
1362
+ argument_dict = argument_dict["kwargs"]
1363
+ for name, value in argument_dict.items():
1364
+ if name in bound_types:
1365
+ if value is not None and not isinstance(value, bound_types[name]):
1366
+ raise TypeError(f"The parameter '{name}' must be {bound_types[name]}, but got {type(value)}")
1367
+ return func(*args, **kwargs)
1368
+
1369
+ return wrapper
1370
+
1371
+ return type_check
1372
+
1373
+
1374
+ def check_hook_fn(hook_type, hook_fn):
1375
+ """Check hook fn"""
1376
+ if context.get_context("mode") != context.PYNATIVE_MODE:
1377
+ logger.warning(f"'{hook_type}' function is only supported in pynative mode, you can use "
1378
+ f"context.set_context to set pynative mode.")
1379
+ return False
1380
+
1381
+ if not isinstance(hook_fn, (FunctionType, MethodType)):
1382
+ raise TypeError(f"When using 'hook_type(hook_fn)', the type of 'hook_fn' must be python "
1383
+ f"function, but got {type(hook_fn)}.")
1384
+
1385
+ if hook_fn.__code__.co_name == "staging_specialize":
1386
+ raise TypeError(f"Decorating hook function {hook_fn.__name__} with '@jit' is not supported.")
1387
+
1388
+ tensor_hook_func_args_num = 1
1389
+ pre_hook_func_args_num = 2
1390
+ forward_hook_and_backward_hook_func_args_num = 3
1391
+ # Real args number, exclude class method self param
1392
+ hook_fn_args_num = len(inspect.signature(hook_fn).parameters)
1393
+
1394
+ if hook_type == "register_hook" and hook_fn_args_num != tensor_hook_func_args_num:
1395
+ raise TypeError(f"Tensor hook function {hook_fn.__name__} arg num should be {tensor_hook_func_args_num}, but "
1396
+ f"got {hook_fn_args_num}")
1397
+
1398
+ if hook_type == "register_forward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
1399
+ raise TypeError(f"forward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num}, "
1400
+ f"but got {hook_fn_args_num}")
1401
+
1402
+ if (hook_type == "register_forward_hook" and
1403
+ hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
1404
+ raise TypeError(f"forward_hook function {hook_fn.__name__} args num should be "
1405
+ f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")
1406
+
1407
+ if hook_type == "register_backward_pre_hook" and hook_fn_args_num != pre_hook_func_args_num:
1408
+ raise TypeError(f"backward_pre_hook function {hook_fn.__name__} args num should be {pre_hook_func_args_num},"
1409
+ f" but got {hook_fn_args_num}")
1410
+
1411
+ if (hook_type == "register_backward_hook" and
1412
+ hook_fn_args_num != forward_hook_and_backward_hook_func_args_num):
1413
+ raise TypeError(f"backward_hook function {hook_fn.__name__} args num should be "
1414
+ f"{forward_hook_and_backward_hook_func_args_num}, but got {hook_fn_args_num}")
1415
+
1416
+ return True
1417
+
1418
+
1419
+ _set_record = {}