mindspore 2.4.0__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1459 @@
1
+ # Copyright 2020-2023 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """constexpr util"""
17
+ from __future__ import absolute_import
18
+ from enum import IntEnum
19
+
20
+
21
+ from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
22
+ from mindspore.ops import functional as F
23
+ from mindspore.ops import operations as P
24
+ from mindspore.ops.composite import base
25
+ from mindspore.ops._primitive_cache import _get_cache_prim
26
+ from mindspore.ops.operations._inner_ops import TensorCopySlices, SliceGetItem, \
27
+ TopTypeof, IsParameter, GetitemTensorIndexInfo, SetitemTensorIndexInfo, \
28
+ SelectView, CopyWithSlice
29
+ from mindspore.ops.operations._sequence_ops import TensorToTuple, TensorToScalar, TupleToTensor
30
+ from mindspore.common import dtype as mstype
31
+ from mindspore.common._register_for_tensor import tensor_operator_registry
32
+ from mindspore.common.initializer import Zero
33
+ from mindspore.common import Tensor, CSRTensor, COOTensor, mutable
34
+ from mindspore import ops
35
+ from mindspore.ops.primitive import _primexpr
36
+ from mindspore import _checkparam as validator
37
+ from mindspore.common._stub_tensor import _convert_stub
38
+
39
# ---------------------------------------------------------------------------
# Module-level primitive/operator instances used by the indexing helpers of
# this module. They are created once at import time and reused.
# ---------------------------------------------------------------------------
slice_get_item = SliceGetItem()
hyper_map = base.HyperMap()
stack = P.Stack(axis=-1)
copy_slice = TensorCopySlices()
toptypeof = TopTypeof()
is_parameter = IsParameter()

# Both index-info primitives are parameterized by whether the backend is Ascend.
getitem_tensor_index_info = GetitemTensorIndexInfo(const_utils.is_ascend())
setitem_tensor_index_info = SetitemTensorIndexInfo(const_utils.is_ascend())

# NOTE: `selevt_view` (sic) keeps its original spelling; other code may import it.
selevt_view = SelectView()
copy_with_slice = CopyWithSlice()
50
+
51
def strided_slice(data, begin_strides, end_strides, step_strides, begin_mask=0, end_mask=0, ellipsis_mask=0,
                  new_axis_mask=0, shrink_axis_mask=0):
    """
    Slice `data` with a StridedSlice primitive obtained through `_get_cache_prim`.

    The five mask arguments are passed to the primitive's constructor; the
    begin/end/step values are passed when the primitive is applied to `data`.
    """
    slice_op = _get_cache_prim(P.StridedSlice)(begin_mask, end_mask, ellipsis_mask,
                                               new_axis_mask, shrink_axis_mask)
    return slice_op(data, begin_strides, end_strides, step_strides)
57
+
58
+
59
class ValueTransferType(IntEnum):
    """
    Codes for the individual transformation steps of tensor getitem/setitem.

    The index-info primitives return sequences of these integer codes, which
    `data_update` / `value_update` compare against the members below. The values
    presumably mirror a native-side enum, so they must not be renumbered.
    """
    kUnknown = 0
    # --- codes <= kScatterND: data_update routes these to data_update_by_ops ---
    kTensorScatterUpdate = 1
    kExpandDims = 2
    kBroadCast = 3
    kCast = 4
    kSelect = 5
    kGather = 6
    kStrideSlice = 7
    kStrideSliceWithMask = 8
    kGatherND = 9
    kScatterNdUpdate = 10
    kReshape = 11
    kSelectView = 12
    kUnsqueeze = 13
    kCopyView = 14
    kScatterND = 15
    # --- value-side steps handled by value_update ---
    kNumberToTensor = 16
    kHandleSequenceValue = 17
    kByPass = 18
    # --- steps handled directly inside data_update ---
    kReSetItemByIndex = 19
    kCopySlice = 20
    kSetItemByBool = 21
    kEmptyTensor = 22
    kSetItemByEllipsis = 23
    kFormatIndexTensor = 24
    kGetitemByBoolTensor = 25
    kSetitemByBoolTensor = 26
    kJustReturn = 27
    kRaiseIndexError = 28
90
+
91
+
92
def data_update(transfer_types, args, data, new_index, value=None):
    """
    Apply the transfer steps computed for a tensor getitem/setitem and build the result.

    Args:
        transfer_types: Sequence of `ValueTransferType` codes, one per step.
        args: Sequence of per-step arguments, aligned with `transfer_types`.
        data: The tensor being indexed.
        new_index: The normalized index; a kFormatIndexTensor step may rewrite it.
        value: The (already transformed) value to assign, or None for getitem.

    Returns:
        The resulting tensor. Several step kinds return immediately instead of
        processing the remaining steps.

    Raises:
        IndexError: On a kUnknown step, or on a kRaiseIndexError step
            (an out-of-bounds index detected while computing the steps).
    """
    origin_data = data
    for transfer_type, arg in zip(transfer_types, args):
        if transfer_type == ValueTransferType.kUnknown:
            raise IndexError(f"Invalid transfer type {transfer_type}.")
        if transfer_type <= ValueTransferType.kScatterND:
            # All operator-based steps share a single dispatcher and keep transforming `data`.
            data = data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value)
        if transfer_type == ValueTransferType.kJustReturn:
            return _convert_stub(arg)
        if transfer_type == ValueTransferType.kSetItemByBool:
            return tensor_setitem_by_bool(data, new_index, value)
        if transfer_type == ValueTransferType.kCopySlice:
            return copy_slice(data, value.astype(data.dtype), arg[0], arg[1], arg[2])
        if transfer_type == ValueTransferType.kSetItemByEllipsis:
            return tensor_setitem_by_ellipsis(data, new_index, value)
        if transfer_type == ValueTransferType.kReSetItemByIndex:
            data[new_index] = value
            return data
        if transfer_type == ValueTransferType.kEmptyTensor:
            return handle_empty_tensor(arg, data)
        if transfer_type == ValueTransferType.kFormatIndexTensor:
            # Only rewrites the index; later steps consume the formatted index.
            new_index = format_index_tensor(new_index, arg)
        if transfer_type == ValueTransferType.kGetitemByBoolTensor:
            return F.gather_nd(data, new_index.nonzero())
        if transfer_type == ValueTransferType.kSetitemByBoolTensor:
            return handle_setitem_by_bool_tensor(data, new_index, value)
        if transfer_type == ValueTransferType.kRaiseIndexError:
            raise IndexError(
                f'index {arg[0]} is out of bounds for dimension with size {arg[1]}')
    return data
126
+
127
+
128
def data_update_by_ops(transfer_type, arg, data, new_index, origin_data, value=None):
    """
    Apply one operator-based transfer step of tensor getitem/setitem to `data`.

    Args:
        transfer_type: A `ValueTransferType` code no greater than `kScatterND`.
        arg: Step-specific argument (slice bounds and masks, target shape, axis, ...).
        data: Tensor produced by the previous steps.
        new_index: Normalized index used by the gather/scatter/select steps.
        origin_data: The tensor originally being indexed; returned by kCopyView.
        value: Value to write for setitem steps; None for getitem.

    Returns:
        The updated tensor, or `origin_data` for a kCopyView step.

    Raises:
        IndexError: If `transfer_type` is not handled by this dispatcher.
    """
    if transfer_type == ValueTransferType.kStrideSliceWithMask:
        stride_info, mask_index = arg[0], arg[1]
        data = strided_slice(data, stride_info[0], stride_info[1], stride_info[2],
                             mask_index[0], mask_index[1], 0, 0, mask_index[2])
    elif transfer_type == ValueTransferType.kGatherND:
        if isinstance(new_index, list):
            new_index = handle_multi_dim_index_tensor(new_index, arg)
        # Format against the leading dims of `data` covered by the last index axis.
        new_index = format_index_tensor(new_index, (None, F.shape(data)[:F.shape(new_index)[-1]]))
        data = F.gather_nd(data, new_index)
    elif transfer_type == ValueTransferType.kTensorScatterUpdate:
        if isinstance(new_index, list):
            new_index = handle_multi_dim_index_tensor(new_index, arg)
        data = F.tensor_scatter_update(data, new_index, value)
    elif transfer_type == ValueTransferType.kScatterNdUpdate:
        # NOTE(review): the result is not assigned; this presumably relies on
        # scatter_nd_update updating `data` in place (e.g. a Parameter) - confirm.
        F.scatter_nd_update(data, new_index, value)
    elif transfer_type == ValueTransferType.kSelect:
        data = F.select(Tensor(new_index), value, data)
    elif transfer_type == ValueTransferType.kSelectView:
        data = selevt_view(data, arg[0], arg[1])
    elif transfer_type == ValueTransferType.kCopyView:
        value = _broadcast(F.shape(data), F.cast(value, F.dtype(data)))
        data = copy_with_slice(data, value)
        # The copy goes through the view; the caller receives the original tensor.
        return origin_data
    elif transfer_type == ValueTransferType.kReshape:
        data = F.reshape(data, arg)
    elif transfer_type == ValueTransferType.kGather:
        data = F.gather(data, new_index, 0)
    elif transfer_type == ValueTransferType.kExpandDims:
        data = F.expand_dims(data, 0)
    elif transfer_type == ValueTransferType.kUnsqueeze:
        data = F.unsqueeze(data, arg)
    elif transfer_type == ValueTransferType.kStrideSlice:
        data = strided_slice(data, arg[0], arg[1], arg[2])
    else:
        raise IndexError(f"Invalid transfer type {transfer_type}.")
    return data
169
+
170
+
171
def value_update(transfer_types, args, data, value):
    """
    Transform the value of a tensor setitem before it is written into `data`.

    Args:
        transfer_types: Sequence of `ValueTransferType` codes for the value side.
        args: Per-step arguments aligned with `transfer_types`.
        data: The tensor being assigned into (its dtype drives the casts).
        value: The raw value supplied to `__setitem__`.

    Returns:
        The transformed value.

    Raises:
        IndexError: If a step code is not a value-side transfer type.
    """
    for transfer_type, arg in zip(transfer_types, args):
        if transfer_type == ValueTransferType.kByPass:
            continue
        if transfer_type == ValueTransferType.kNumberToTensor:
            value = F.cast(value, F.dtype(data))
        elif transfer_type == ValueTransferType.kHandleSequenceValue:
            op_type, index = arg
            if op_type == const_utils.SET_ITEM_BY_ONE_TENSOR:
                index = Tensor(index)
            value = _generate_updates_from_sequence(
                data, index, value, op_type)
        elif transfer_type == ValueTransferType.kExpandDims:
            value = F.expand_dims(value, arg)
        elif transfer_type == ValueTransferType.kBroadCast:
            value = _broadcast(arg, value.astype(F.dtype(data)))
        elif transfer_type == ValueTransferType.kCast:
            value = F.cast(value, F.dtype(data))
        elif transfer_type == ValueTransferType.kReshape:
            value = F.reshape(value, arg)
        elif transfer_type == ValueTransferType.kScatterND:
            value = F.scatter_nd(arg[0], value, arg[1])
        else:
            raise IndexError(f"Invalid transfer type {transfer_type}.")
    return value
197
+
198
+
199
def _tensor_getitem(self, index):
    """
    Implement `Tensor.__getitem__`.

    `getitem_tensor_index_info` normalizes `index` and yields the transfer steps
    (with their arguments) that `data_update` then applies to `self`.
    """
    normalized_index, step_types, step_args = getitem_tensor_index_info(self, index)
    return data_update(step_types, step_args, self, normalized_index)
204
+
205
+
206
def _tensor_setitem(self, index, value):
    """
    Implement `Tensor.__setitem__`.

    `setitem_tensor_index_info` returns (new_index, value step types, value step
    args, data step types, data step args). The value is transformed first, then
    the data update is applied. For a "view" index the result is `(self,)`.
    """
    info = setitem_tensor_index_info(self, index, value)
    new_index = info[0]
    transformed_value = value_update(info[1], info[2], self, value)
    result = data_update(info[3], info[4], self, new_index, transformed_value)
    if new_index == "view":
        return (self,)
    return result
219
+
220
+
221
# Register the indexing handlers on the tensor operator registry.
for _dunder_name, _handler in (("__getitem__", _tensor_getitem),
                               ("__setitem__", _tensor_setitem)):
    setattr(tensor_operator_registry, _dunder_name, _handler)
del _dunder_name, _handler
223
+
224
+
225
def _tensor_add(self, other):
    """
    Tensor `+`.

    A COOTensor operand performs the addition itself; a list/tuple operand is
    first converted to a tensor with `self`'s dtype.
    """
    if isinstance(other, COOTensor):
        return other + self
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    return F.add(self, other)
231
+
232
+
233
def _tensor_sub(self, other):
    """
    Tensor `-`.

    A list/tuple on either side is converted to a tensor; `self` is converted
    first, using `other`'s dtype. A COOTensor `other` is subtracted via
    `tensor_scatter_sub` at its indices.
    """
    if isinstance(self, (tuple, list)):
        self = sequence_to_tensor(self, F.dtype(other))
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    if not isinstance(other, COOTensor):
        return F.sub(self, other)
    return F.tensor_scatter_sub(self, other.indices, other.values)
241
+
242
+
243
def _tensor_mul(self, other):
    """
    Tensor `*`.

    Sparse operands (CSRTensor / COOTensor) perform the multiplication
    themselves; a list/tuple is converted to a tensor with `self`'s dtype.
    """
    if isinstance(other, (CSRTensor, COOTensor)):
        return other * self
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    return F.mul(self, other)
249
+
250
+
251
def _tensor_matmul(self, other):
    """Tensor `@`: delegate directly to `F.matmul`."""
    product = F.matmul(self, other)
    return product
253
+
254
+
255
def _tensor_div(self, other):
    """
    Tensor `/`.

    A list/tuple on either side is converted to a tensor; `self` is converted
    first (using `other`'s dtype), then `other` (using `self`'s dtype).
    """
    if isinstance(self, (tuple, list)):
        self = sequence_to_tensor(self, F.dtype(other))
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    quotient = F.div(self, other)
    return quotient
261
+
262
+
263
def _tensor_mod(self, other):
    """
    Tensor `%`, implemented with `F.floormod`.

    A list/tuple on either side is converted to a tensor; `self` is converted
    first (using `other`'s dtype), then `other` (using `self`'s dtype).
    """
    if isinstance(self, (tuple, list)):
        self = sequence_to_tensor(self, F.dtype(other))
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    remainder = F.floormod(self, other)
    return remainder
269
+
270
+
271
def _tensor_pow(self, other):
    """Tensor `**` (self ** other); a list/tuple exponent becomes a tensor of `self`'s dtype."""
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    power = F.tensor_pow(self, other)
    return power
275
+
276
+
277
def _tensor_rpow(self, other):
    """Reflected `**` (other ** self); a list/tuple base becomes a tensor of `self`'s dtype."""
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    power = F.tensor_pow(other, self)
    return power
281
+
282
+
283
def _tensor_floordiv(self, other):
    """
    Tensor `//`.

    A list/tuple on either side is converted to a tensor; `self` is converted
    first (using `other`'s dtype), then `other` (using `self`'s dtype).
    """
    if isinstance(self, (tuple, list)):
        self = sequence_to_tensor(self, F.dtype(other))
    if isinstance(other, (tuple, list)):
        other = sequence_to_tensor(other, F.dtype(self))
    floor_quotient = F.floordiv(self, other)
    return floor_quotient
289
+
290
+
291
# Register the arithmetic operator handlers on the tensor operator registry
# (same names, same order as before).
for _dunder_name, _handler in (
        ('__add__', _tensor_add),
        ('__sub__', _tensor_sub),
        ('__mul__', _tensor_mul),
        ('__matmul__', _tensor_matmul),
        ('__truediv__', _tensor_div),
        ('__mod__', _tensor_mod),
        ('__pow__', _tensor_pow),
        ('__rpow__', _tensor_rpow),
        ('__floordiv__', _tensor_floordiv),
):
    setattr(tensor_operator_registry, _dunder_name, _handler)
del _dunder_name, _handler
300
+
301
+
302
def _scalar_to_tensor(input_x):
    """
    Convert a scalar to a Tensor.

    Constant scalars use the ScalarToTensor primitive with the scalar's own
    dtype. Non-constant scalars are turned into a tensor by adding a mutable
    0-d `Tensor(0)` to them.
    """
    if not ops.isconstant(input_x):
        return ops.add(input_x, mutable(Tensor(0)))
    return P.ScalarToTensor()(input_x, ops.dtype(input_x))
307
+
308
+
309
@_primexpr
def _check_scalar_tensor_args(args):
    """
    Validate the index arguments of `item()` on a 0-d tensor.

    Only "no index" (`()`) or an explicit `(None,)` is accepted; anything else is
    reported through `const_utils.raise_value_error`.
    """
    if args in ((None,), ()):
        return
    const_utils.raise_value_error("For item, the index of scalar Tensor should not be set.")
314
+
315
+
316
def tensor_item(data, *args):
    """
    Implement `Tensor.item(*args)`.

    Supported forms:
        * 0-d tensor, no index (or `None`): the tensor's value as a scalar.
        * No index / `None` on a tensor of shape (1,): its single element as a scalar.
        * As many int indices as `data.ndim` (either as varargs or one tuple):
          handled by `tensor_index_by_tuple`.
        * A single int index on a multi-dim tensor: element of the flattened tensor.
    Errors are reported through `const_utils.raise_value_error` / `raise_type_error`.
    """
    if data.ndim == 0:
        _check_scalar_tensor_args(args)
        return TensorToScalar()(data)

    # Accept the tuple form: a.item((i, j)) behaves like a.item(i, j).
    if len(args) == 1 and isinstance(args[0], tuple):
        args = args[0]

    index_types = hyper_map(F.typeof, args)
    if not args or const_utils.judge_index_type(index_types[0], mstype.type_none):
        if data.shape != (1,):
            const_utils.raise_value_error("Can only convert an array of size 1 to a Python scalar")
        return TensorToScalar()(data[0])

    if not const_utils.judge_indexes_types(index_types, mstype.int64):
        const_utils.raise_type_error("The index object cannot be interpreted as an integer")

    if len(args) == data.ndim:
        return tensor_index_by_tuple(data, args)
    if len(args) > 1:
        const_utils.raise_value_error("Incorrect number of indices for array")
    flat_element = _tensor_index_by_integer(F.reshape(data, (-1,)), args[0])
    return TensorToScalar()(flat_element)
340
+
341
+
342
def tensor_itemset(data, *args):
    """
    Implement `Tensor.itemset(*args)`.

    * `itemset(value)`: delegated to `tensor_itemset_with_number`.
    * `itemset(int_index, value)`: sets an element of the flattened tensor.
    * `itemset(tuple_index, value)`: sets the element at a full tuple index.
    Zero or more than two arguments are rejected via `const_utils.raise_value_error`.
    """
    if not args:
        const_utils.raise_value_error("'Tensor.itemset()' must have at least one argument, but got None.")
    if len(args) == 2:
        index, number_value = args
        if const_utils.judge_index_type(F.typeof(index), mstype.int64):
            return tensor_itemset_by_number_with_number(data, index, number_value)
        if isinstance(index, tuple):
            return tensor_itemset_by_tuple_with_number(data, index, number_value)
        const_utils.raise_type_error("The index object cannot be interpreted as an integer")
    if len(args) > 2:
        too_many_msg = const_utils.gen_exception_msg("'Tensor.itemset()' must have at most 2 argument, but got {}.",
                                                     len(args))
        const_utils.raise_value_error(too_many_msg)
    return tensor_itemset_with_number(data, args[0])
357
+
358
+
359
# Register the scalar item accessors on the tensor operator registry.
for _method_name, _handler in (("item", tensor_item),
                               ("itemset", tensor_itemset)):
    setattr(tensor_operator_registry, _method_name, _handler)
del _method_name, _handler
361
+
362
+
363
def tensor_itemset_with_number(data, number_value):
    """
    Handle `Tensor.itemset(value)` without an index.

    Only tensors of shape (1,) and numeric values are accepted; violations are
    reported via `const_utils.raise_index_error`. Returns a new tensor built from
    `(number_value,)` with `data`'s dtype.
    """
    is_number = const_utils.judge_index_type(F.typeof(number_value), mstype.number_type)
    if not is_number:
        type_msg = const_utils.gen_exception_msg(
            "'Tensor.itemset()' only support number input, but got {}", number_value)
        const_utils.raise_index_error(type_msg)
    if data.shape != (1,):
        shape_msg = const_utils.gen_exception_msg(
            "Only tensor which shape is (1,) support 1 arg that means omit index, "
            "but the tensor shape is {} and got 1 input.", data.shape)
        const_utils.raise_index_error(shape_msg)
    return const_utils.make_tensor((number_value,), F.dtype(data))
375
+
376
+
377
def tensor_itemset_by_number_with_number(data, int_index, number_value):
    """
    Write ``number_value`` at flat position ``int_index`` of ``data``.

    ``data`` is flattened to 1-D, updated along axis 0 by
    ``tensor_setitem_by_number_with_number``, and reshaped back to its original shape.
    """
    flat_updated = tensor_setitem_by_number_with_number(F.reshape(data, (-1,)), int_index, number_value)
    original_shape = F.shape(data)
    return F.reshape(flat_updated, original_shape)
382
+
383
+
384
def tensor_itemset_by_tuple_with_number(data, tuple_index, nubmer_value):
    """
    Write a number at the position addressed by ``tuple_index``.

    ``tuple_index`` must have exactly ``data.ndim`` entries; otherwise an error is
    reported through ``const_utils.raise_index_error``. The number is cast to the
    dtype of ``data`` before delegating to ``tensor_itemset_by_tuple_with_tensor``.
    (The misspelled parameter name is kept for interface compatibility.)
    """
    index_len = len(tuple_index)
    data_rank = data.ndim
    if index_len != data_rank:
        err_msg = const_utils.gen_exception_msg(
            "Tuple index len({}) is not same to tensor dimension({})", index_len, data_rank)
        const_utils.raise_index_error(err_msg)
    casted_value = F.cast(nubmer_value, F.dtype(data))
    return tensor_itemset_by_tuple_with_tensor(data, tuple_index, casted_value)
391
+
392
+
393
def _broadcast(broadcast_shape, x):
    """Return ``x`` broadcast to ``broadcast_shape``; skip the op when the shape already matches."""
    if F.shape(x) != broadcast_shape:
        x = F.broadcast_to(x, broadcast_shape)
    return x
398
+
399
+
400
def _transform_indexing_tensor(broadcast_shape, final_shape, new_shape, item):
    """
    Bring an index tensor to ``final_shape``.

    The tensor is broadcast to ``broadcast_shape``, reshaped to ``new_shape``, and
    finally broadcast to ``final_shape``.
    """
    broadcast_item = _broadcast(broadcast_shape, item)
    reshaped_item = F.reshape(broadcast_item, new_shape)
    return _broadcast(final_shape, reshaped_item)
404
+
405
+
406
def _transform_ellipsis_to_slice(data, tuple_index, op_name):
    """
    Replace every Ellipsis in ``tuple_index`` by empty slices.

    Each Ellipsis expands into ``data.ndim - n`` empty slices, where ``n`` is the
    number of slice, int, tensor and sequence entries in ``tuple_index``.

    Args:
        data (Tensor): The tensor being indexed.
        tuple_index (tuple): The tuple index, possibly containing Ellipsis.
        op_name (str): Operation name forwarded to the type-position helper.

    Returns:
        tuple, the index with each Ellipsis expanded.

    Raises:
        IndexError: Through ``const_utils.raise_index_error``, if the
            dimension-consuming entries outnumber the dimensions of ``data``.
    """
    data_shape = F.shape(data)
    indexes_types = hyper_map(toptypeof, tuple_index)
    slice_positions, ellipsis_positions, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)

    ellipsis_occupy_dims = data.ndim - (len(slice_positions) + len(int_positions) +
                                        len(tensor_positions) + len(sequence_positions))
    # The old inner guard `ellipsis_cnt >= 0` was always true, so any overflow is an error.
    if ellipsis_occupy_dims < 0:
        exp_msg = const_utils.gen_exception_msg(
            "Tuple index {} out of range of tensor shape {}.", tuple_index, data_shape)
        const_utils.raise_index_error(exp_msg)

    tuple_index_new = ()
    for i, index in enumerate(tuple_index):
        if i in ellipsis_positions:
            for _ in range(ellipsis_occupy_dims):
                tuple_index_new += (const_utils.make_empty_slice(),)
        else:
            tuple_index_new += (index,)
    return tuple_index_new
435
+
436
+
437
def handle_empty_tensor(arg, data):
    """
    Create an empty tensor of shape ``arg`` with the dtype of ``data``.

    If ``arg`` contains a zero-sized dimension, the tensor is built with a ``Zero``
    initializer flagged with ``__enable_zero_dim__``. Otherwise it is built with
    ``const_utils.make_tensor([], dtype, arg)``.
    """
    if 0 not in arg:
        return const_utils.make_tensor([], data.dtype, arg)
    zero_initializer = Zero()
    zero_initializer.__enable_zero_dim__ = True
    return Tensor(shape=arg, dtype=data.dtype, init=zero_initializer)
444
+
445
+
446
def handle_multi_dim_index_tensor(new_index, arg):
    """
    Build a stacked int64 index tensor from the entries of ``new_index``.

    ``arg`` selects one of two modes:

    * A 1-tuple ``(broadcast_shape,)``: every entry of ``new_index`` is turned into
      a Tensor, broadcast to ``broadcast_shape``, cast to int64, and stacked.
    * A 6-tuple ``(broadcast_shape, final_shape, index_tensor_new_shape,
      slice_shapes, tensor_positions, fancy_position)``:
      - entries at ``tensor_positions`` go through ``_transform_indexing_tensor``
        (broadcast, reshape, broadcast to ``final_shape``);
      - every other entry (presumably an expanded slice element list; TODO confirm)
        is reshaped with ``const_utils.compute_slice_shape`` and broadcast to
        ``final_shape``.

    Returns:
        Tensor, the int64 index tensors stacked together.
    """
    slice_cnt = 0  # number of non-tensor (slice) entries processed so far
    new_indies_tensor = []
    if len(arg) == 1:
        # Simple mode: only a common broadcast shape is provided.
        broadcast_shape = arg[0]
        new_index = hyper_map(F.partial(Tensor), new_index)
        broadcast_tensors = hyper_map(
            F.partial(_broadcast, broadcast_shape), new_index)
        new_broadcast_tensors = ()
        for tensor in broadcast_tensors:
            new_broadcast_tensors += (F.cast(tensor, mstype.int64),)
        new_index = stack(new_broadcast_tensors)
        return new_index
    broadcast_shape, final_shape, index_tensor_new_shape, slice_shapes, tensor_positions, fancy_position = arg
    for i, index in enumerate(new_index):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          Tensor(index))
            new_indies_tensor.append(F.cast(transform_tensor, mstype.int64))
        else:
            # The slice's target shape depends on how many slices precede it (slice_cnt).
            shape = const_utils.compute_slice_shape(
                slice_shapes, len(broadcast_shape), slice_cnt, fancy_position)
            array = Tensor(index).reshape(shape)
            slice_index_tensor = _broadcast(final_shape, array)
            new_indies_tensor.append(F.cast(slice_index_tensor, mstype.int64))
            slice_cnt += 1
    new_index = stack(new_indies_tensor)
    return new_index
475
+
476
+
477
def format_index_tensor(index, arg):
    """
    Wrap negative index values by adding the corresponding dimension size.

    ``arg`` is ``(format_indices, format_dims)``.

    * If ``index`` is a list, each ``index[format_indices[k]]`` is replaced in place
      by ``select(v < 0, v + format_dims[k], v)``, and the list is returned.
    * Otherwise ``index`` and ``format_dims`` are converted to Tensors and the same
      element-wise selection is returned.
    """
    format_indices, format_dims = arg
    if not isinstance(index, list):
        index_tensor = Tensor(index)
        dims_tensor = Tensor(format_dims)
        return F.select(index_tensor < 0, index_tensor + dims_tensor, index_tensor)
    for position, dim_size in zip(format_indices, format_dims):
        current = index[position]
        index[position] = F.select(current < 0, current + dim_size, current)
    return index
488
+
489
+
490
def handle_setitem_by_bool_tensor(data, index, value):
    """
    Implement ``data[bool_mask] = value``.

    ``value`` is cast to ``data``'s dtype and broadcast to
    ``(num_true,) + data.shape[index.ndim:]``. It is then scattered into a
    ``data``-shaped tensor at the mask's nonzero positions. Finally, elements where
    the broadcast mask is True are taken from it, and all other elements from
    ``data``. If the mask has no True entry, ``data`` is returned unchanged.
    """
    value = F.cast(value, F.dtype(data))
    true_positions = index.nonzero()
    num_true = true_positions.shape[0]
    if num_true == 0:
        return data
    updates_shape = (num_true,) + data.shape[index.ndim:]
    updates = _broadcast(updates_shape, value)
    scattered = F.scatter_nd(true_positions, updates, data.shape)
    padded_mask_shape = const_utils.generate_padding_shape(index.shape, len(data.shape))
    full_mask = _broadcast(data.shape, index.reshape(padded_mask_shape))
    return F.select(full_mask, scattered, data)
503
+
504
+
505
def _expand_data_dims(data, tuple_index):
    """
    Turn ``None`` and ``True`` entries of ``tuple_index`` into new axes of ``data``.

    * ``None`` becomes an empty slice.
    * ``True`` becomes the int64 tensor ``[0]``.
    * A ``False`` entry is rejected via ``const_utils.raise_index_error``.

    For every replaced position ``p`` (in increasing order), ``data`` gets a new
    axis inserted at ``p`` with ``F.expand_dims``.

    Returns:
        tuple, ``(expanded_data, new_tuple_index)``.
    """
    entry_types = hyper_map(toptypeof, tuple_index)
    new_axes = ()
    converted_index = ()
    for pos, (entry, entry_type) in enumerate(zip(tuple_index, entry_types)):
        if isinstance(entry_type, mstype.NoneType):
            converted_index += (const_utils.make_empty_slice(),)
            new_axes += (pos,)
        elif isinstance(entry_type, mstype.Bool):
            if not entry:
                const_utils.raise_index_error("Bool element of tuple index must be 'True', but got 'False'.")
            converted_index += (const_utils.make_tensor([0], mstype.int64),)
            new_axes += (pos,)
        else:
            converted_index += (entry,)

    for axis in new_axes:
        data = F.expand_dims(data, axis)

    return data, converted_index
525
+
526
+
527
+ def _convert_list_index_to_tensor(list_index):
528
+ """convert list to tensor"""
529
+ has_bool = False
530
+ has_int = False
531
+ has_no_bool_int = False
532
+ for idx in list_index:
533
+ if isinstance(idx, bool):
534
+ has_bool = True
535
+ elif isinstance(idx, int):
536
+ has_int = True
537
+ else:
538
+ has_no_bool_int = True
539
+
540
+ all_bool = has_bool and not has_int and not has_no_bool_int
541
+ all_int = has_int and not has_bool and not has_no_bool_int
542
+ all_bool_or_int = not has_no_bool_int
543
+
544
+ if all_int:
545
+ index_tensor = TupleToTensor()(tuple(list_index), mstype.int64)
546
+ return index_tensor
547
+
548
+
549
+ if all_bool:
550
+ index_tensor = TupleToTensor()(tuple(list_index), mstype.bool_)
551
+ return index_tensor
552
+
553
+ # convert bool to int if index is mixture of (bool, int)
554
+ if all_bool_or_int:
555
+ new_index = []
556
+ for idx in list_index:
557
+ if isinstance(idx, bool):
558
+ new_idx = int(idx)
559
+ new_index.append(new_idx)
560
+ else:
561
+ new_index.append(idx)
562
+ index_tensor = TupleToTensor()(tuple(new_index), mstype.int64)
563
+ return index_tensor
564
+
565
+ return None
566
+
567
+
568
class _TensorIndexGetitem(base.TensorIndexGetitem_):
    """
    Callable that performs tensor getitem for a slice or tuple index.

    In this module it is used as ``_tensor_index_getitem(data, index)``, where
    ``index`` is either a Python slice or a tuple index.

    Args:
        data (Tensor): The tensor being indexed.
        index: The slice or tuple index to apply to ``data``.

    Returns:
        Tensor, the indexed result.

    Note:
        ``__call__`` below is a no-op placeholder. The actual indexing logic
        presumably comes from the ``base.TensorIndexGetitem_`` base class (TODO confirm).
    """

    def __call__(self, *args):
        pass


_tensor_index_getitem = _TensorIndexGetitem('tensor_index_getitem')
584
+
585
+
586
def tensor_index_by_slice(data, slice_index):
    """Return ``data[slice_index]`` for a Python slice, computed by ``_tensor_index_getitem``."""
    sliced = _tensor_index_getitem(data, slice_index)
    return sliced
589
+
590
+
591
def tensor_index_by_number(data, number_index):
    """
    Tensor getitem by a scalar number.

    Bools go to ``_tensor_index_by_bool`` and ints go to ``_tensor_index_by_integer``.
    Any other number (e.g. a float) is rejected through ``const_utils.raise_index_error``.
    """
    # bool must be checked before int: bool is a subclass of int.
    if isinstance(number_index, bool):
        return _tensor_index_by_bool(data, number_index)
    if not isinstance(number_index, int):
        exp_msg = const_utils.gen_exception_msg(
            "Number index of tensor must be int or bool, but got {}.", number_index)
        return const_utils.raise_index_error(exp_msg)
    return _tensor_index_by_integer(data, number_index)
600
+
601
+
602
def _tensor_index_by_bool(data, bool_value):
    """
    Tensor getitem by a single Python bool.

    * ``data[True]`` inserts a new leading axis of size 1.
    * ``data[False]`` is rejected when the shape of ``data`` is statically known.
    * With an unknown shape, ``data[False]`` returns ``data`` unchanged.
    """
    lowest_rank, highest_rank = 0, 7
    const_utils.judge_data_dim(data.ndim, lowest_rank, highest_rank)
    if bool_value:
        return F.expand_dims(data, 0)
    if not F.is_sequence_value_unknown(F.shape(data)):
        return const_utils.raise_index_error("When tensor is indexed by a bool object, the value only support 'True'.")
    return data
612
+
613
+
614
def get_stride_info_from_integer(int_index):
    """
    Express an integer index as a one-element strided slice.

    Returns:
        tuple, ``((int_index,), (int_index + 1,), (1,))`` as (begin, end, strides).
    """
    return (int_index,), (int_index + 1,), (1,)
620
+
621
+
622
def _tensor_index_by_integer(data, int_index):
    """
    Tensor getitem by a single integer along axis 0, via ``strided_slice``.

    The begin/end masks set bits 2..7, and ``shrink_axis_mask=1`` removes axis 0
    from the result.
    """
    begin, end, strides = get_stride_info_from_integer(int_index)
    # Bits 2..7 set: 4 + 8 + 16 + 32 + 64 + 128 == 252.
    full_range_mask = 0b11111100
    shrink_first_axis = 1
    return strided_slice(data, begin, end, strides, full_range_mask, full_range_mask, 0, 0, shrink_first_axis)
633
+
634
+ def _check_dim_shape_valid(data, tensor_index):
635
+ """check dim and shape of tensor_index for tensor(bool) indexing"""
636
+ if data.ndim < tensor_index.ndim:
637
+ raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
638
+ f"dim of index:{tensor_index.ndim}, dim of data:{data.ndim}")
639
+ if data.shape[:tensor_index.ndim] != tensor_index.shape[:]:
640
+ raise IndexError(f"The shape of index {tensor_index.shape} does not match the shape "
641
+ f"of the indexed data {data.shape}")
642
+
643
+
644
def tensor_index_by_bool_tensor(data, tensor_index):
    """
    Tensor getitem by a boolean mask tensor.

    When the shape of ``data`` is static, the mask is validated first. The result
    gathers ``data`` at the mask's nonzero positions.
    """
    shape_is_known = not F.is_sequence_value_unknown(F.shape(data))
    if shape_is_known:
        _check_dim_shape_valid(data, tensor_index)
    true_positions = tensor_index.nonzero()
    return F.gather_nd(data, true_positions)
650
+
651
+
652
def tensor_index_by_tensor(data, tensor_index):
    """
    Tensor getitem by a single tensor index.

    * Integer indices are wrapped by ``data.shape[0]`` when negative and gathered along axis 0.
    * Boolean indices are delegated to ``tensor_index_by_bool_tensor``.
    * Any other dtype is rejected through ``const_utils.raise_index_error``.
    """
    lowest_rank, highest_rank = 0, 7
    if not F.is_sequence_value_unknown(F.shape(data)):
        const_utils.judge_data_dim(data.ndim, lowest_rank, highest_rank)
    index_dtype = F.dtype(tensor_index)
    if const_utils.check_type_isinstance(index_dtype, mstype.Int):
        axis0_size = F.shape(data)[0]
        wrapped_index = F.select(tensor_index < 0, tensor_index + axis0_size, tensor_index)
        return F.gather(data, wrapped_index, 0)
    if const_utils.check_type_isinstance(index_dtype, mstype.Bool):
        return tensor_index_by_bool_tensor(data, tensor_index)
    exp_msg = const_utils.gen_exception_msg(
        "The tensor index must be int or bool type, but got {}.", index_dtype)
    const_utils.raise_index_error(exp_msg)
    return data
666
+
667
+
668
def tensor_index_by_list(data, list_index):
    """
    Tensor getitem by a list index.

    * An all-bool list acts as a mask over axis 0 when ``data.shape[0]`` is
      constant; its length must equal ``data.shape[0]``.
    * Otherwise, a list made only of ints and/or bools is converted to an index
      tensor (bools become 0/1 in a mixture) and handled by ``tensor_index_by_tensor``.
    * Any other list is treated as a tuple index.

    Raises:
        IndexError: If ``list_index`` is empty (via ``const_utils.raise_index_error``),
            or if an all-bool list's length differs from ``data.shape[0]``.
    """
    min_data_dim, max_data_dim = 1, 8
    if F.isconstant(data.ndim):
        const_utils.judge_data_dim(data.ndim, min_data_dim, max_data_dim)

    # Must run before the mask branch: all(...) over an empty list is True, so an
    # empty list would otherwise be reported as a boolean length mismatch.
    if not list_index:
        const_utils.raise_index_error("When tensor is indexed by list, the list can't be empty.")

    data_shape = F.shape(data)
    if F.isconstant(data_shape[0]) and all(isinstance(i, bool) for i in list_index):
        if data_shape[0] != len(list_index):
            raise IndexError(
                f'dimension is {data_shape[0]} but corresponding boolean dimension is {len(list_index)}')
        tensor_index = Tensor(list_index).nonzero()
        return F.gather_nd(data, tensor_index)

    index_tensor = _convert_list_index_to_tensor(list_index)
    if index_tensor is not None:
        return tensor_index_by_tensor(data, index_tensor)

    return tensor_index_by_tuple(data, tuple(list_index))
693
+
694
+
695
def judge_tuple_index_dim_check_error(index_dim, data_dim):
    """
    Reject a tuple index that consumes more dimensions than the data has.

    Raises:
        IndexError: If ``index_dim > data_dim``.
    """
    if index_dim <= data_dim:
        return
    raise IndexError(f"The dim of index cannot be greater than indexed data, but got "
                     f"dim of index:{index_dim}, dim of data:{data_dim}")
700
+
701
+
702
def judge_tuple_index_dim(data, tuple_index):
    """
    Check that ``tuple_index`` does not consume more dimensions than ``data`` has.

    * A bool tensor entry consumes ``entry.ndim`` dimensions.
    * ``None``, Ellipsis and bool scalars consume none.
    * Every other entry consumes one.
    """
    consumed_dims = 0
    for entry in tuple_index:
        entry_type = toptypeof(entry)
        if isinstance(entry_type, mstype.TensorType) and entry.dtype == mstype.bool_:
            consumed_dims += entry.ndim
        elif not isinstance(entry_type, (mstype.NoneType, mstype.Ellipsis_, mstype.Bool)):
            consumed_dims += 1
    judge_tuple_index_dim_check_error(consumed_dims, data.ndim)
712
+
713
+
714
def tensor_index_by_tuple(data, tuple_index):
    """
    Tensor getitem by a tuple index (which may contain None).

    * An empty tuple returns ``data`` itself.
    * For a static shape, the number of consumed dimensions is validated.
    * Bool tensor entries are processed by ``_handle_bool_tensor``. If any of the
      reported nonzero shapes contains 0, the alternative ``zero_index`` it
      returns is used instead.
    """
    if not tuple_index:
        return data

    if not F.is_sequence_value_unknown(F.shape(data)):
        judge_tuple_index_dim(data, tuple_index)
    processed_index, zero_index, nonzero_shapes = _handle_bool_tensor(tuple_index)
    for shape in nonzero_shapes:
        if 0 in shape:
            processed_index = zero_index
            break

    return _tensor_index_getitem(data, processed_index)
728
+
729
+
730
def get_slice_stride(slice_index, dim_size):
    """
    Return ``(start, stop, step)`` of ``slice_index``.

    Missing fields default to ``0``, ``dim_size`` and ``1``. Tensor-valued fields
    are converted to Python ints. Negative values are not normalised.
    """
    start = slice_get_item(slice_index, "start")
    stop = slice_get_item(slice_index, "stop")
    step = slice_get_item(slice_index, "step")

    start = 0 if start is None else start
    stop = dim_size if stop is None else stop
    step = 1 if step is None else step

    start = int(start) if isinstance(start, Tensor) else start
    stop = int(stop) if isinstance(stop, Tensor) else stop
    step = int(step) if isinstance(step, Tensor) else step

    return start, stop, step
753
+
754
+
755
def cal_tuple_slice_mask(data_shape, tuple_index):
    """
    Compute strided_slice ``begin_mask`` and ``end_mask`` for a tuple index.

    Bit ``i`` of ``begin_mask`` (``end_mask``) is set when ``tuple_index[i]`` is a
    slice without a start (stop). It is also set for every dimension
    ``i >= len(tuple_index)`` of ``data_shape``.
    """
    begin_mask = 0
    end_mask = 0
    for dim, entry in enumerate(tuple_index):
        if not isinstance(entry, slice):
            continue
        bit = 2 ** dim
        if slice_get_item(entry, "start") is None:
            begin_mask += bit
        if slice_get_item(entry, "stop") is None:
            end_mask += bit
    # Dimensions not covered by the tuple index use their full extent.
    for dim in range(len(tuple_index), len(data_shape)):
        begin_mask += 2 ** dim
        end_mask += 2 ** dim
    return begin_mask, end_mask
767
+
768
+
769
def _generate_indices_from_tuple_of_tensor(tuple_index, op_name):
    """
    Build a stacked int64 index tensor from a tuple of integer tensors.

    1. The dtypes are validated as integer types.
    2. All tensors are broadcast to their common shape; a 1 is prepended if that
       shape has fewer than 2 dimensions.
    3. The tensors are cast to int64 and stacked.
    """
    index_dtypes = hyper_map(F.dtype, tuple_index)
    const_utils.check_types_valid(index_dtypes, mstype.int_type, op_name)
    index_shapes = hyper_map(F.shape, tuple_index)
    target_shape = const_utils.generate_broadcast_shape(index_shapes, op_name)
    if len(target_shape) < 2:
        target_shape = (1,) + target_shape
    broadcasted = hyper_map(F.partial(_broadcast, target_shape), tuple_index)
    int64_tensors = ()
    for item in broadcasted:
        int64_tensors += (F.cast(item, mstype.int64),)
    return stack(int64_tensors)
783
+
784
+
785
def parse_check_slice_index(index_out, dim_size):
    """
    Tell whether a slice is statically known to select nothing.

    Returns ``const_utils.check_slice_empty(...)`` on the normalised slice when
    start, stop and step are all constants; otherwise ``False``.
    """
    start, stop, step = const_utils.normalize_slice(index_out, dim_size)
    all_constant = F.isconstant(start) and F.isconstant(stop) and F.isconstant(step)
    if not all_constant:
        return False
    return const_utils.check_slice_empty(start, stop, step)
792
+
793
+
794
def _generate_indices_from_tuple(data, tuple_index, op_name, fancy_position):
    """
    Generate a stacked index tensor from a tuple mixing slices, ints, sequences and tensors.

    Pass 1 converts each entry to an index form:
    - int entries become int64 scalar tensors (after ``const_utils.check_range``);
    - sequence entries become tensors via ``const_utils.sequence_to_index``;
    - tensor entries are cast to int64;
    - slice entries become element lists via ``const_utils.transform_slice_to_ele_list``.
    Int and sequence positions are appended to ``tensor_positions`` so that the
    second pass treats them as tensors.

    Pass 2 broadcasts every entry to the common ``final_shape`` computed by
    ``const_utils.generate_index_info_from_tuple_of_mixed_tensors`` and stacks them.

    Returns:
        Tensor, the stacked index tensor; or ``False`` if any slice is statically
        empty (the caller then leaves the data unchanged).
    """
    data_shape = F.shape(data)
    tensor_indexes, slice_indexes = [], []  # NOTE(review): slice_indexes is filled but never read
    indexes_types = hyper_map(toptypeof, tuple_index)
    slice_positions, _, _, int_positions, _, tensor_positions, sequence_positions = \
        const_utils.get_pos_of_indexes_types(indexes_types, op_name)
    tuple_index_new, slice_shapes = (), ()
    # zip stops at the shorter of tuple_index and data_shape.
    for i, (index, dim_size) in enumerate(zip(tuple_index, data_shape)):
        if i in int_positions:
            int_index = const_utils.check_range(index, dim_size)
            tensor_index = F.scalar_to_tensor(int_index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions += (i,)
        elif i in sequence_positions:
            tensor_index = const_utils.sequence_to_index(index, dim_size)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
            tensor_positions += (i,)
        elif i in tensor_positions:
            # NOTE(review): only int dtypes pass this check although the message mentions bool.
            invalid = const_utils.check_type_invalid(F.dtype(index), mstype.int_type)
            if invalid:
                exp_msg = const_utils.gen_exception_msg(
                    "The tensor element in tuple index must be int or bool type, but got {}.", F.dtype(index))
                const_utils.raise_index_error(exp_msg)
            tensor_index = F.cast(index, mstype.int64)
            tuple_index_new += (tensor_index,)
            tensor_indexes.append(tensor_index)
        elif i in slice_positions:
            # A statically empty slice means nothing is selected.
            if parse_check_slice_index(index, dim_size):
                return False
            slice_ele_list_index = const_utils.transform_slice_to_ele_list(index, dim_size)
            slice_shapes += (len(slice_ele_list_index),)
            tuple_index_new += (slice_ele_list_index,)
            slice_indexes.append(slice_ele_list_index)

    tensor_indexes_shapes = hyper_map(F.shape, tensor_indexes)
    broadcast_shape, index_tensor_new_shape, final_shape, fancy_position = \
        const_utils.generate_index_info_from_tuple_of_mixed_tensors(tensor_positions, tensor_indexes_shapes,
                                                                    slice_shapes, op_name, fancy_position)

    final_index_tensors = []
    slice_cnt = 0  # number of slice entries processed so far in pass 2
    for i, index in enumerate(tuple_index_new):
        if i in tensor_positions:
            transform_tensor = _transform_indexing_tensor(broadcast_shape, final_shape, index_tensor_new_shape,
                                                          index)
            final_index_tensors.append(transform_tensor)
        elif i in slice_positions:
            slice_index_tensor = convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape,
                                                         slice_shapes, fancy_position)
            final_index_tensors.append(slice_index_tensor)
            slice_cnt += 1

    indices = stack(final_index_tensors)
    return indices
851
+
852
+
853
def sequence_to_tensor(value, dtype):
    """
    Convert a list/tuple ``value`` into a tensor of ``dtype``.

    * All elements are tensors: they are stacked and cast.
    * No element is a tensor: the sequence is converted with ``TupleToTensor``.
      float16 goes through float32 first, then a cast.
    * Mixed: non-tensor elements are made into tensors of ``dtype``; everything
      is stacked and cast. The original notes this path only handles 1-D
      tensor/non-tensor mixtures.
    """
    element_kind = const_utils.check_value_elements(hyper_map(toptypeof, value))

    if element_kind == const_utils.ALL_TENSOR:
        return F.stack(value).astype(dtype)

    if element_kind == const_utils.NO_TENSOR:
        if isinstance(value, list):
            value = tuple(value)
        if dtype == mstype.float16:
            as_float32 = TupleToTensor()(value, mstype.float32)
            return F.cast(as_float32, dtype)
        return TupleToTensor()(value, dtype)

    mixed = ()
    for element in value:
        mixed += (element if isinstance(element, Tensor) else const_utils.make_tensor(element, dtype),)
    return F.stack(mixed).astype(dtype)
876
+
877
+
878
def _generate_updates_from_sequence(data, index, value, op_type):
    """
    Build a setitem updates tensor from a list/tuple ``value``.

    The sequence is converted with ``data``'s dtype. For
    ``SET_ITEM_BY_NON_TENSOR`` it is returned directly; otherwise it is shaped
    by ``_generate_updates_from_tensor``.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    if op_type != const_utils.SET_ITEM_BY_NON_TENSOR:
        return _generate_updates_from_tensor(data, index, value_tensor, op_type)
    return value_tensor
884
+
885
+
886
def _generate_updates_from_tensor(data, index, value, op_type):
    """
    Cast ``value`` to ``data``'s dtype and broadcast it to the updates shape.

    The updates shape is computed by ``const_utils.generate_updates_shape`` from
    the shapes of ``data`` and ``index``.
    """
    casted_value = value.astype(data.dtype)
    target_shape = const_utils.generate_updates_shape(data.shape, index.shape, op_type)
    return ops.broadcast_to(casted_value, target_shape)
892
+
893
+
894
+ # Tensor getitem implementations are above this line, setitem implementations below.
895
+
896
def _tensor_index_transfer(index, broadcast_shape, final_shape, new_shape):
    """
    Bring a tuple-index tensor to ``final_shape``.

    The tensor is broadcast to ``broadcast_shape``, reshaped to ``new_shape``, and
    broadcast to ``final_shape``. If ``final_shape`` has a zero-sized dimension, a
    zero-filled tensor of that shape is returned instead.
    """
    if 0 in final_shape:
        return F.fill(index.dtype, final_shape, 0)

    # broadcast_to with an empty target shape () is not supported on Ascend,
    # so the first broadcast is skipped in that case.
    item = index if broadcast_shape == () else F.broadcast_to(index, broadcast_shape)
    reshaped = F.reshape(item, new_shape)
    return F.broadcast_to(reshaped, final_shape)
908
+
909
+
910
def reshape_with_check(x, new_shape):
    """Reshape ``x`` to ``new_shape``; a Tensor shape is first converted to a tuple."""
    target_shape = TensorToTuple()(new_shape) if isinstance(new_shape, Tensor) else new_shape
    return F.reshape(x, target_shape)
914
+
915
+
916
class _TensorIndexSetitem(base.TensorIndexSetitem_):
    """
    Callable that prepares tensor setitem (not getitem).

    In this module it is used in two ways:

    * ``_tensor_index_setitem(data, slice_index, value)`` returns
      ``(indices, value_shape, start, stop, step, value)`` for slice assignment;
    * ``_tensor_index_setitem(data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)``
      returns the scatter indices for tuple assignment.

    Note:
        ``__call__`` below is a no-op placeholder. The actual logic presumably comes
        from the ``base.TensorIndexSetitem_`` base class (TODO confirm).
    """

    def __call__(self, *args):
        pass


_tensor_index_setitem = _TensorIndexSetitem('tensor_index_setitem')
933
+
934
+
935
def tensor_setitem_by_slice(self, index, value):
    """
    Implement ``tensor[slice] = value``.

    ``_tensor_index_setitem`` provides the scatter indices, the value shape and the
    normalised start/stop/step.

    * If ``start == stop``, the tensor is returned unchanged.
    * Off Ascend with a unit step, ``copy_slice`` is used.
    * Otherwise ``tensor_scatter_update`` is used.
    """
    indices, value_shape, start, stop, step, value = _tensor_index_setitem(self, index, value)
    if start == stop:
        return self
    value = F.broadcast_to(value, value_shape)
    if not const_utils.is_ascend() and step == 1:
        return copy_slice(self, value, (start,), (stop,), (step,))
    return F.tensor_scatter_update(self, indices, value)
948
+
949
+
950
def tensor_setitem_by_ellipsis(self, index, value):
    """
    Implement ``tensor[...] = value``, dispatching on the type of ``value``.

    ``index`` (the Ellipsis) is not used beyond dispatch.
    """
    if isinstance(value, Tensor):
        return tensor_setitem_by_ellipsis_with_tensor(self, value)
    if isinstance(value, (int, float, bool)):
        return tensor_setitem_by_ellipsis_with_number(self, value)
    return tensor_setitem_by_ellipsis_with_sequence(self, value)
956
+
957
+
958
def _tensor_setitem_by_int_tensor_with_tensor(data, index, value):
    """
    Implement ``data[int_tensor] = value`` along axis 0.

    * A 0-d index is promoted to 1-d.
    * Negative indices are wrapped by ``data.shape[0]``.
    * ``value`` is cast and broadcast to ``index.shape + data.shape[1:]``.
    * A Parameter is updated in place; otherwise a new tensor is returned.
    """
    if F.rank(index) == 0:
        index = F.expand_dims(index, -1)

    data_shape = F.shape(data)
    target_shape = index.shape + data_shape[1:]
    updates = ops.broadcast_to(F.cast(value, F.dtype(data)), target_shape)
    axis0_size = data_shape[0]
    wrapped = F.select(index < 0, index + axis0_size, index)
    scatter_index = F.expand_dims(wrapped, -1)
    if not is_parameter(data):
        return F.tensor_scatter_update(data, scatter_index, updates)
    F.scatter_nd_update(data, scatter_index, updates)
    return data
974
+
975
+
976
def _tensor_setitem_by_bool_tensor_with_tensor(data, index, value):
    """
    Implement ``data[bool_mask] = value``.

    The mask and the cast value are each padded with ``generate_padding_shape``
    and broadcast to ``data.shape``. Elements where the mask is True come from the
    value; all others come from ``data``.
    """
    target_shape = data.shape
    target_rank = len(target_shape)
    mask = index.reshape(const_utils.generate_padding_shape(index.shape, target_rank))
    mask = F.broadcast_to(mask, target_shape)
    fill = F.cast(value, F.dtype(data))
    fill = fill.reshape(const_utils.generate_padding_shape(fill.shape, target_rank))
    fill = F.broadcast_to(fill, target_shape)
    return F.select(mask, fill, data)
985
+
986
+
987
def tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):
    """
    Implement ``data[tensor_index] = tensor`` for int or bool index dtypes.

    Int indices use the scatter path; all other indices use the boolean-mask path.
    """
    index_kind = const_utils.get_index_tensor_dtype(F.dtype(index))
    if index_kind != const_utils.INT_:
        return _tensor_setitem_by_bool_tensor_with_tensor(data, index, value_tensor)
    return _tensor_setitem_by_int_tensor_with_tensor(data, index, value_tensor)
995
+
996
+
997
def tensor_setitem_by_tensor_with_number(data, index, value):
    """Implement ``data[tensor_index] = number`` by casting the number to ``data``'s dtype."""
    casted_value = F.cast(value, F.dtype(data))
    return tensor_setitem_by_tensor_with_tensor(data, index, casted_value)
1000
+
1001
+
1002
def tensor_setitem_by_tensor_with_sequence(data, index, value):
    """Implement ``data[tensor_index] = list/tuple`` by converting the sequence to a tensor first."""
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_tensor_with_tensor(data, index, value_tensor)
1006
+
1007
+
1008
def tensor_setitem_by_tuple_with_number(data, tuple_index, value):
    """Implement ``data[tuple_index] = number`` by casting the number to ``data``'s dtype."""
    casted_value = F.cast(value, F.dtype(data))
    return tensor_setitem_by_tuple_with_tensor(data, tuple_index, casted_value)
1012
+
1013
+
1014
def tensor_setitem_by_list(data, index, value):
    """
    Implement ``data[list_index] = value``.

    A list made only of ints and/or bools becomes an index tensor. Any other list
    is treated as a tuple index.
    """
    index_tensor = _convert_list_index_to_tensor(index)
    if index_tensor is None:
        return tensor_setitem_by_tuple_with_tensor(data, tuple(index), value)
    return tensor_setitem_by_tensor_with_tensor(data, index_tensor, value)
1021
+
1022
+
1023
+
1024
class _PreSetitemByTuple(base.PreSetitemByTuple_):
    """
    Callable that pre-processes a tuple-index setitem.

    In this module it is called as ``_pre_setitem_by_tuple(data, tuple_index, value)``
    and returns ``(special_index, tuple_index, new_value_shape, idx_advanced,
    broadcast_data_shape)``. As interpreted by ``tensor_setitem_by_tuple_with_tensor``:

    * ``special_index == 0`` means nothing is assigned;
    * ``special_index == 1`` means the whole tensor is assigned.

    Note:
        ``__call__`` below is a no-op placeholder. The actual logic presumably comes
        from the ``base.PreSetitemByTuple_`` base class (TODO confirm).
    """

    def __init__(self, name):
        """Initialize _PreSetitemByTuple."""
        base.PreSetitemByTuple_.__init__(self, name)

    def __call__(self, *args):
        pass


_pre_setitem_by_tuple = _PreSetitemByTuple('pre_setitem_by_tuple')
1045
+
1046
+
1047
class _HandleBoolTensor(base.HandleBoolTensor_):
    """
    Callable that pre-processes bool-tensor entries of a tuple index.

    In this module it is called as ``_handle_bool_tensor(tuple_index)`` and returns
    ``(tuple_index, zero_index, non_zero_shapes)``. Callers treat a ``0`` inside
    any of ``non_zero_shapes`` as "the mask selects nothing":

    * getitem then uses ``zero_index`` instead;
    * setitem returns the data unchanged.

    Presumably bool tensors are rewritten into their nonzero coordinates (TODO confirm).

    Note:
        ``__call__`` below is a no-op placeholder. The actual logic presumably comes
        from the ``base.HandleBoolTensor_`` base class (TODO confirm).
    """

    def __init__(self, name):
        """Initialize _HandleBoolTensor."""
        base.HandleBoolTensor_.__init__(self, name)

    def __call__(self, *args):
        pass


_handle_bool_tensor = _HandleBoolTensor('handle_bool_tensor')
1068
+
1069
+
1070
def tensor_setitem_by_tuple_with_tensor(data, tuple_index, value):
    """
    Assigns the tensor by tuple with tensor value (``data[tuple_index] = value``).

    Paths, in order:
    1. ``copy_slice`` fast path when ``const_utils.use_copy_slice(tuple_index)``
       holds and the backend is not Ascend.
    2. Early return of ``data`` when a bool-tensor entry selects nothing, or when
       ``_pre_setitem_by_tuple`` reports ``special_index == 0``.
    3. Whole-tensor assignment (``data[True] = value``) for an empty processed
       index or ``special_index == 1``.
    4. General scatter path. A Parameter is updated in place with
       ``scatter_nd_update``; otherwise ``tensor_scatter_update`` returns a new tensor.
    """
    if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
        # From the code below, tuple_index[0] is an int row index and
        # tuple_index[1] a slice along dim 1. The slice step is discarded, so
        # presumably use_copy_slice only accepts unit steps (TODO confirm).
        dim1_start, dim1_stop, _ = const_utils.normalize_slice(
            tuple_index[1], data.shape[1])
        if isinstance(dim1_start, Tensor):
            dim1_start = int(dim1_start)
        if isinstance(dim1_stop, Tensor):
            dim1_stop = int(dim1_stop)
        if dim1_stop - dim1_start <= 0:
            return data
        dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
        start = (dim0_start, dim1_start)
        stop = (dim0_start + 1, dim1_stop)
        step = (1, 1)
        value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
        value = F.broadcast_to(value, value_shape)
        return copy_slice(data, value.astype(data.dtype), start, stop, step)
    tuple_index, _, non_zero_shapes = _handle_bool_tensor(tuple_index)

    # A zero-sized nonzero shape means a bool mask selected nothing: no-op.
    for non_zero_shape in non_zero_shapes:
        if 0 in non_zero_shape:
            return data
    value = value.astype(data.dtype)
    special_index, tuple_index, new_value_shape, idx_advanced, _broadcast_data_shape \
        = _pre_setitem_by_tuple(data, tuple_index, value)
    if special_index == 0:
        return data
    value = F.reshape(value, new_value_shape)
    if not tuple_index or special_index == 1:
        data[True] = value
        return data

    # The broadcast data shape counts as empty when it is Tensor([0]) or ().
    # NOTE(review): the Tensor comparison is element-wise; this relies on it having a single element.
    empty_broadcast_data_shape = False
    if isinstance(_broadcast_data_shape, Tensor) and _broadcast_data_shape == Tensor([0]):
        empty_broadcast_data_shape = True
    if isinstance(_broadcast_data_shape, tuple) and not _broadcast_data_shape:
        empty_broadcast_data_shape = True
    indices = _tensor_index_setitem(
        data, tuple_index, value, idx_advanced, empty_broadcast_data_shape)

    updates = _generate_updates_from_tensor(
        data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
    if is_parameter(data):
        F.scatter_nd_update(data, indices, updates)
        return data
    return F.tensor_scatter_update(data, indices, updates)
1117
+
1118
def tensor_itemset_by_tuple_with_tensor(data, tuple_index, value):
    """
    Assigns the tensor by tuple with tensor value (used by ``Tensor.itemset``).

    Steps:
    1. Ellipsis entries are expanded into empty slices.
    2. The same ``copy_slice`` fast path as ``tensor_setitem_by_tuple_with_tensor`` is tried.
    3. Otherwise ``remove_expanded_dims`` (defined elsewhere) yields the
       processed index, value and ``idx_advanced``.
    4. The index tensor comes from either the all-tensor or the mixed generator.
       Unlike ``tensor_setitem_by_tuple_with_tensor``, there is no Parameter
       in-place branch: the result always comes from ``tensor_scatter_update``.
    """
    op_name = const_utils.TENSOR_SETITEM
    tuple_index = _transform_ellipsis_to_slice(data, tuple_index, op_name)

    if const_utils.use_copy_slice(tuple_index) and not const_utils.is_ascend():
        # tuple_index[0] is used as an int row index and tuple_index[1] as a
        # slice along dim 1 (step discarded; presumably always 1, TODO confirm).
        dim1_start, dim1_stop, _ = const_utils.normalize_slice(tuple_index[1], data.shape[1])
        if isinstance(dim1_start, Tensor):
            dim1_start = int(dim1_start)
        if isinstance(dim1_stop, Tensor):
            dim1_stop = int(dim1_stop)
        if dim1_stop - dim1_start <= 0:
            return data
        dim0_start = tuple_index[0] if tuple_index[0] >= 0 else tuple_index[0] + data.shape[0]
        start = (dim0_start, dim1_start)
        stop = (dim0_start + 1, dim1_stop)
        step = (1, 1)
        value_shape = (dim1_stop - dim1_start,) + data.shape[2:]
        value = F.broadcast_to(value, value_shape)
        return copy_slice(data, value.astype(data.dtype), start, stop, step)
    tuple_index, value, idx_advanced = remove_expanded_dims(tuple_index, F.shape(data), value)

    # remove_expanded_dims signals "nothing to assign" with False.
    if tuple_index is False:
        return data
    if len(tuple_index) == 1:
        data[tuple_index[0]] = value
        return data
    indexes_types = hyper_map(toptypeof, tuple_index)
    contain_type = const_utils.tuple_index_type_cnt(indexes_types, op_name)

    if contain_type == const_utils.ALL_TENSOR:
        indices = _generate_indices_from_tuple_of_tensor(tuple_index, op_name)
    else:
        indices = _generate_indices_from_tuple(data, tuple_index, op_name, idx_advanced)
        # False means a statically empty slice: nothing to assign.
        if indices is False:
            return data
    updates = _generate_updates_from_tensor(data, indices, value, const_utils.SET_ITEM_BY_TUPLE_OF_TENSOR)
    return F.tensor_scatter_update(data, indices, updates)
1156
+
1157
+
1158
def tensor_setitem_by_tuple_with_sequence(data, tuple_index, value):
    """Implement ``data[tuple_index] = list/tuple`` by converting the sequence to a tensor first."""
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_tuple_with_tensor(data, tuple_index, value_tensor)
1161
+
1162
+
1163
def tensor_setitem_by_number_with_number(data, index, value):
    """
    Assigns the tensor by number with number value (``data[index] = value`` along axis 0).

    Args:
        data (Tensor): Tensor to update. A Parameter is updated in place with
            ``scatter_nd_update``; otherwise a new tensor is returned.
        index (int): Position along axis 0; negative values count from the end.
        value: Value cast to ``data``'s dtype and broadcast to ``(1,) + data.shape[1:]``.

    Returns:
        Tensor, ``data`` itself for a Parameter, otherwise the updated tensor.

    Raises:
        IndexError: If ``index`` is outside ``[-dim_size, dim_size)``.
    """
    data_shape = F.shape(data)
    dim_size = data_shape[0]
    # Check bounds on the caller's index before wrapping negatives. Checking after
    # wrapping against -dim_size let indices below -dim_size through, and the
    # message would have reported the wrapped value instead of the caller's.
    if index < -dim_size or index >= dim_size:
        raise IndexError(f'index {index} is out of bounds for axis 0 with size {dim_size}')
    if index < 0:
        index += dim_size
    index = F.cast(index, mstype.int64)
    index = F.reshape(index, (1, 1))

    updates = F.cast(value, data.dtype)
    updates_shape = (1,) + data_shape[1:]
    updates = ops.broadcast_to(updates, updates_shape)

    if is_parameter(data):
        F.scatter_nd_update(data, index, updates)
        return data
    return F.tensor_scatter_update(data, index, updates)
1182
+
1183
+
1184
def tensor_setitem_by_number_with_sequence(data, index, value):
    """Implement ``data[int] = list/tuple``: convert the sequence with ``data``'s dtype, then assign."""
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_number_with_tensor(data, index, value_tensor)
1188
+
1189
+
1190
def tensor_setitem_by_number_with_tensor(data, index, value):
    """Implements ``data[index] = value`` for an integer ``index`` and a tensor ``value``.

    The integer/number path already casts and broadcasts ``value``, so this
    simply forwards to it.
    """
    result = tensor_setitem_by_number_with_number(data, index, value)
    return result
1192
+
1193
+
1194
def tensor_setitem_by_ellipsis_with_number(data, value):
    """Implements ``data[...] = value`` for a scalar ``value``.

    Returns a new tensor with ``data``'s shape and dtype, every element set to ``value``.
    """
    return F.fill(F.dtype(data), F.shape(data), value)
1199
+
1200
+
1201
def tensor_setitem_by_ellipsis_with_tensor(data, value):
    """Implements ``data[...] = value`` for a tensor ``value``.

    ``value`` is cast to ``data``'s dtype and broadcast to ``data``'s shape.
    When ``value`` has more dimensions than ``data`` it is first reshaped to
    ``data``'s shape, so its element count must already match.
    """
    target_shape = F.shape(data)
    value = value.astype(F.dtype(data))

    src_shape = F.shape(value)
    if len(src_shape) > len(target_shape):
        src_shape = target_shape
    reshaped = F.reshape(value, src_shape)
    return F.broadcast_to(reshaped, target_shape)
1216
+
1217
+
1218
def tensor_setitem_by_ellipsis_with_sequence(data, value):
    """Implements ``data[...] = value`` for a list/tuple ``value``.

    The sequence becomes a tensor of ``data``'s dtype and is then handled by the
    tensor-valued ellipsis path.
    """
    value_tensor = sequence_to_tensor(value, F.dtype(data))
    return tensor_setitem_by_ellipsis_with_tensor(data, value_tensor)
1222
+
1223
+
1224
def tensor_setitem_by_bool(data, index, value):
    """Implements ``data[index] = value`` where ``index`` is a Python bool.

    ``value`` (a list/tuple or anything ``F.cast`` accepts) is always converted
    to ``data``'s dtype first. The conversion also runs for ``False``, so any
    error it raises surfaces for both branches.

    * ``index`` truthy: the result is ``value`` broadcast to ``data``'s shape.
      A ``value`` with more dims than ``data`` is reshaped to ``data``'s shape first.
    * ``index`` falsy: nothing is selected and ``data`` is returned unchanged.
    """
    data_shape = F.shape(data)
    data_dtype = F.dtype(data)
    if isinstance(value, (list, tuple)):
        value = sequence_to_tensor(value, data_dtype)
    else:
        value = F.cast(value, data_dtype)

    if not index:
        # False selects nothing, so there is nothing to assign.
        return data

    value_shape = F.shape(value)
    if len(value_shape) > len(data_shape):
        source_shape = data_shape
    else:
        source_shape = value_shape
    value = F.reshape(value, source_shape)
    return F.broadcast_to(value, data_shape)
1244
+
1245
+
1246
def tensor_in_sequence(x, y):
    """Returns a scalar bool tensor telling whether ``y`` contains a tensor equal to ``x``.

    Only elements of ``y`` that are Tensors with the same shape and dtype as
    ``x`` are compared. Such an element matches when all of its entries equal ``x``.
    """
    found = const_utils.scalar_to_tensor(False)
    for candidate in y:
        if isinstance(candidate, Tensor):
            if x.shape == candidate.shape and x.dtype == candidate.dtype:
                all_equal = F.equal(x, candidate).all()
                found = F.logical_or(all_equal, found)
    return found
1253
+
1254
+
1255
@_primexpr
def remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim):
    """Expands a bool-tensor index into one integer index tensor per indexed axis.

    ``index_out.nonzero()`` yields the coordinates of the True entries. Each
    column of that result is appended to ``indices_out``, and its shape is
    appended (in place) to ``shapes``. ``cur_dim`` advances by the number of
    columns.

    Returns:
        ``(indices_out, shapes, cur_dim)``, or ``(None, shapes, cur_dim)`` when
        no entry of the bool index is True.
    """
    coords = index_out.nonzero()
    if coords.shape[0] == 0:
        return None, shapes, cur_dim
    num_axes = coords.shape[1]
    for axis in range(num_axes):
        column = coords[:, axis]
        indices_out = indices_out + (column,)
        shapes.append(F.shape(column))
    return indices_out, shapes, cur_dim + num_axes
1267
+
1268
+
1269
def remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim):
    """Records one tensor index element for ``remove_expanded_dims``.

    A bool-dtype tensor is delegated to
    ``remove_expanded_dims_parse_bool_tensor_index``. Any other tensor is
    appended to ``indices_out`` as-is: its shape is appended (in place) to
    ``shapes`` and ``cur_dim`` advances by one.

    Returns:
        ``(indices_out, shapes, cur_dim)``; the bool path may return ``None``
        as the first element.
    """
    if index_out.dtype != mstype.bool_:
        shapes.append(F.shape(index_out))
        return indices_out + (index_out,), shapes, cur_dim + 1
    return remove_expanded_dims_parse_bool_tensor_index(index_out, indices_out, shapes, cur_dim)
1277
+
1278
+
1279
def remove_expanded_dims(tuple_index, data_shape, value):
    """Removes expanded dimensions in tuple_index and value.

    Walks ``tuple_index`` element by element (each normalized by ``format_index``):

    * ``None``: marks an expanded dimension; it contributes nothing to the
      output indices.
    * slice: kept in the output indices and consumes one data dimension.
    * Tensor / Python bool: treated as an advanced index. Tensors are recorded
      through ``remove_expanded_dims_parse_tensor_index``.
    * anything else: ``const_utils.raise_index_error('invalid index type')``.

    ``value`` is then reshaped by dropping the dims that ``not_expanded_dim``
    marks as expanded (via ``const_utils.filter_expanded_dims``).

    Returns:
        ``(indices_out, value, idx_advanced)``. ``indices_out`` is ``False``
        when ``has_false`` was set, and ``(True,)`` when no index survived.
        ``(False, value, 0)`` is returned early when a bool tensor index has
        no True entries.
    """
    not_expanded_dim = ()
    shapes = []
    has_true = False
    has_false = False
    has_sequence = False
    indices_out = ()  # with dimension expansion indices removed
    idx_tensor = -1  # index of the previous tensor
    idx_advanced = -1  # index of the first advanced index in expanded tensor
    cur_dim = 0  # current dimension of the data to be indexed

    for i, v in enumerate(tuple_index):
        index_out = format_index(v, data_shape, cur_dim)

        if index_out is None:
            # None: an expanded (new) axis; does not consume a data dimension.
            not_expanded_dim += (False,)
        elif const_utils.is_slice(index_out):
            indices_out += (index_out,)
            not_expanded_dim += (True,)
            # NOTE(review): parse_check_slice_index is defined elsewhere;
            # presumably it reports whether this slice selects nothing — confirm.
            has_false = has_false or parse_check_slice_index(
                index_out, data_shape[cur_dim])
            cur_dim += 1
        elif isinstance(index_out, (Tensor, bool)):  # advanced index
            if idx_advanced == -1:
                idx_advanced = len(not_expanded_dim)
            elif i - idx_tensor > 1:
                # Advanced indices separated by another index element: position
                # resets to 0. NOTE(review): appears to mirror NumPy's rule that
                # non-adjacent advanced indices place their dims first — confirm.
                idx_advanced = 0
            idx_tensor = i
            if isinstance(index_out, Tensor):
                indices_out, shapes, cur_dim = \
                    remove_expanded_dims_parse_tensor_index(index_out, indices_out, shapes, cur_dim)
                # None here means a bool tensor index with no True entries.
                if indices_out is None:
                    return False, value, 0
                if index_out.dtype != mstype.bool_ and F.rank(index_out) > 0:
                    has_sequence = True
            # Only a Python bool can be `is True` / `is False` here.
            has_true = has_true or index_out is True
            has_false = has_false or index_out is False
        else:
            const_utils.raise_index_error('invalid index type')

    broadcast_shape = const_utils.generate_broadcast_shape(shapes, const_utils.TENSOR_SETITEM)
    if has_false:
        # The broadcast tensor indices must hold exactly one element in this case.
        if F.shape_mul(broadcast_shape) != 1:
            const_utils.raise_index_error('unable to broadcast indices')
        indices_out = False
    else:
        expand_true = has_true and not (has_false or has_sequence)  # whether to expand dimension at True
        tensor_index_ndim = len(broadcast_shape)  # ndim of tensor indices
        rem_ndim = len(data_shape) - cur_dim  # number of remaining dimensions in data not indexed
        not_expanded_dim, idx_advanced = const_utils.rem_not_expanded_dims(idx_advanced, expand_true,
                                                                           tensor_index_ndim,
                                                                           rem_ndim, not_expanded_dim)
        if not indices_out:
            indices_out = (True,)

    # Runs for both branches above: drop the expanded dims from value's shape.
    value_shape = const_utils.filter_expanded_dims(F.shape(value), not_expanded_dim)
    value = F.reshape(value, value_shape)
    return indices_out, value, idx_advanced
1338
+
1339
+
1340
def format_index(idx, data_shape, cur_dim):
    """Normalizes one element of a tuple index against dimension ``cur_dim``.

    * list/tuple: converted by ``const_utils.sequence_to_index``.
    * int (but not bool): converted to an int64 tensor by ``const_utils.make_tensor``.
    * integer-dtype Tensor: negative entries are shifted by ``data_shape[cur_dim]``.
    * everything else (bools, bool tensors, slices, None, ...) is returned
      unchanged. Bool tensors are expanded later in ``remove_expanded_dims``.
    """
    axis_len = data_shape[cur_dim]
    if isinstance(idx, (tuple, list)):
        return const_utils.sequence_to_index(idx, axis_len)
    if isinstance(idx, int) and not isinstance(idx, bool):
        return const_utils.make_tensor(idx, mstype.int64, None, axis_len)
    if isinstance(idx, Tensor):
        if const_utils.get_index_tensor_dtype(idx.dtype) == const_utils.INT_:
            return F.select(idx < 0, idx + axis_len, idx)
    return idx
1354
+
1355
+
1356
@_primexpr
def _check_shape_mul(shape):
    """Raises ValueError when ``F.shape_mul(shape)`` is 0.

    NOTE(review): F.shape_mul is presumably the product of the dims, i.e. this
    rejects zero-size tensors.
    """
    element_count = F.shape_mul(shape)
    if element_count == 0:
        raise ValueError('zero-size tensors are not supported.')
1360
+
1361
+
1362
def reduce_(a, reduce_fn, cmp_fn=None, axis=None, keepdims=False, initial=None, where=True, dtype=None):
    """
    Applies comparison based on cmp_fn and reduction based on reduce_fn.
    If cmp_fn is None, only reduction is performed.

    Args:
        a: Input tensor.
        reduce_fn: Called as ``reduce_fn(a, axes)`` to produce the result.
        cmp_fn: Called as ``cmp_fn(a, initial)`` when ``initial`` is given.
            NOTE(review): it is called unconditionally in that case, so
            ``cmp_fn=None`` together with a non-None ``initial`` fails with
            ``TypeError`` (calling None).
        axis: Validated by ``validator.check_axis_valid`` against ``a``'s rank.
        keepdims: Not read in this body. NOTE(review): presumably already baked
            into ``reduce_fn`` by the caller — confirm.
        initial: Optional scalar (int/float/bool or 0-d Tensor) that is expanded
            to ``a``'s shape.
        where: Bool mask (converted to a Tensor when it is not one). A non-scalar
            or falsy mask requires ``initial``.
        dtype: Output dtype; defaults to ``a``'s dtype.

    Raises (through const_utils / _check_shape_mul):
        type error if ``initial`` is not a scalar; value error if ``a`` has zero
        elements or a where mask is used without ``initial``.
    """

    shape = F.shape(a)
    ndim = F.rank(a)
    if dtype is None:
        dtype = F.dtype(a)
    axes = validator.check_axis_valid(axis, ndim)
    if initial is not None:
        if ((isinstance(initial, Tensor) and F.rank(initial) > 0) or
                not isinstance(initial, (int, float, bool, Tensor))):
            const_utils.raise_type_error('initial must be scalar')

    _check_shape_mul(shape)

    if initial is not None:
        # Expand the scalar initial to a full tensor of a's shape and dtype.
        if isinstance(initial, Tensor):
            initial = F.tile(initial, shape).astype(dtype)
        else:
            initial = F.fill(dtype, shape, initial)
        a = cmp_fn(a, initial)

    if where is not None and not isinstance(where, Tensor):
        where = Tensor(where, dtype=mstype.bool_)

    # A scalar True mask is a no-op; only shaped or falsy masks take this path.
    if where is not None and (where.shape or not where):
        if initial is None:
            const_utils.raise_value_error('initial value must be provided for where masks')
        ndim_orig = F.rank(a)
        # broadcasts input tensors
        shape_out = const_utils.infer_out_shape(F.shape(where), F.shape(a), F.shape(initial))
        # Broadcast the mask via float32 and cast back to bool.
        # NOTE(review): presumably a workaround for bool inputs to broadcast_to — confirm.
        where = where.astype(mstype.float32)
        where = F.broadcast_to(where, shape_out)
        where = where.astype(mstype.bool_)
        a = F.broadcast_to(a, shape_out)
        initial = F.broadcast_to(initial, shape_out)
        # Masked-out positions take the initial value.
        a = F.select(where, a, initial)
        # Broadcasting may have changed a's rank, so remap the reduction axes.
        axes = const_utils.real_axes(ndim_orig, F.rank(a), axes)

    return reduce_fn(a, axes).astype(dtype)
1405
+
1406
+
1407
# Publish reduce_ on the tensor operator registry under the attribute name "reduce"
# (plain attribute assignment, equivalent to setattr with a constant name).
tensor_operator_registry.reduce = reduce_
1408
+
1409
+
1410
def check_indices(dims, indices, mode, allow_negative_index=True):
    """Checks ``indices`` against an axis of length ``dims`` and handles out-of-range values per ``mode``.

    * ``'raise'``: reports ``const_utils.raise_unimplemented_error``.
    * ``'wrap'``: returns ``indices - dims * floor(indices / dims)``.
    * any other mode (clip): entries below the lower bound become 0 and entries
      above ``dims - 1`` become ``dims - 1``.

    The lower bound is ``-dims`` when ``allow_negative_index`` is true, else 0.
    With negative indexing allowed, entries in ``[-dims, -1]`` pass through
    clip mode unchanged.
    """
    idx_shape = F.shape(indices)
    idx_dtype = F.dtype(indices)
    lower = F.fill(idx_dtype, idx_shape, -dims if allow_negative_index else 0)
    upper = F.fill(idx_dtype, idx_shape, dims - 1)
    below = F.tensor_lt(indices, lower)
    above = F.tensor_gt(indices, upper)
    if mode == 'raise':
        const_utils.raise_unimplemented_error('"raise" mode is not implemented')
    if mode == 'wrap':
        period = F.fill(idx_dtype, idx_shape, dims)
        whole_turns = F.tensor_floordiv(indices, period)
        offset = F.tensor_mul(period, whole_turns)
        return F.tensor_sub(indices, offset)
    floor_value = F.fill(idx_dtype, idx_shape, 0)
    result = F.select(below, floor_value, indices)
    return F.select(above, upper, result)
1432
+
1433
+
1434
# Publish check_indices on the tensor operator registry under the attribute name
# "check_indices" (plain attribute assignment, equivalent to setattr with a constant name).
tensor_operator_registry.check_indices = check_indices
1435
+
1436
+
1437
def convert_slice_to_tensor(index, final_shape, slice_cnt, broadcast_shape, slice_shapes, fancy_position):
    """Turns the positions selected by one slice into an int64 index tensor tiled to ``final_shape``.

    The positions in ``index`` are made into an int64 tensor and reshaped to the
    shape given by ``const_utils.compute_slice_shape``. They are then tiled with
    the multiples from ``const_utils.compute_multiples`` so the tensor spans
    ``final_shape``.
    """
    tensor_ndim = len(broadcast_shape)
    slice_shape = const_utils.compute_slice_shape(slice_shapes, tensor_ndim, slice_cnt, fancy_position)
    index_tensor = const_utils.make_tensor(index, mstype.int64)
    index_tensor = index_tensor.reshape(slice_shape)
    multiples = const_utils.compute_multiples(slice_shape, final_shape)
    return F.tile(index_tensor, multiples)
1444
+
1445
+
1446
def check_coo_tensor_input_length(coo_tuple):
    """Validates that ``coo_tuple`` holds exactly three items: indices, values, shape.

    Returns:
        ``coo_tuple`` itself when it has the expected length.

    Raises:
        ValueError: If ``len(coo_tuple)`` is not 3.
    """
    expected = 3
    actual = len(coo_tuple)
    if actual == expected:
        return coo_tuple
    raise ValueError(f"Expect coo_tuple have 3 inputs (indices, values, shape), but got {actual}.")
1452
+
1453
+
1454
def check_csr_tensor_input_length(csr_tuple):
    """Validates that ``csr_tuple`` holds exactly four items: indptr, indices, values, shape.

    Returns:
        ``csr_tuple`` itself when it has the expected length.

    Raises:
        ValueError: If ``len(csr_tuple)`` is not 4.
    """
    item_count = len(csr_tuple)
    if item_count != 4:
        raise ValueError(
            f"Expect csr_tuple have 4 inputs (indptr, indices, values, shape), but got {item_count}.")
    return csr_tuple