mindspore-2.4.0-cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mindspore might be problematic; see the registry's page for this release for more details.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-310-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1425 @@
1
+ # Copyright 2022 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+ """Defines gradient related operators with functional form."""
17
+ from __future__ import absolute_import
18
+ from functools import partial
19
+ import numpy as np
20
+ from mindspore.common import jit, mutable
21
+ from mindspore.common import Tensor
22
+ from mindspore.common import dtype as mstype
23
+ from mindspore.nn.cell import Cell
24
+ from mindspore.nn.grad.cell_grad import _LinearizeInner
25
+ from mindspore.ops.operations.other_ops import stop_gradient_
26
+ from mindspore.ops.primitive import constexpr, _primexpr
27
+ from mindspore.ops.function.array_func import ones, expand_dims, size, reshape, broadcast_to, transpose, zeros
28
+ from mindspore.ops.composite import _Vmap, _Grad, _TaylorOperation, GradOperation
29
+ from mindspore.ops import operations as P
30
+ from mindspore.ops.operations import _inner_ops as inner
31
+
32
# Primitive instances created once at import time and bound to module-level names.
cast, dtype, oneslike = P.Cast(), P.DType(), P.OnesLike()
35
+
36
+
37
@constexpr
def _check_has_aux_type(inputs):
    """Raise ``TypeError`` unless ``inputs`` (the ``has_aux`` flag) is a ``bool``."""
    if isinstance(inputs, bool):
        return
    raise TypeError("The 'has_aux' must be bool type.")
41
+
42
+
43
@constexpr
def _raise_type_error():
    """Unconditionally raise the ``TypeError`` reported for unsupported input types."""
    message = "The inputs type must be a Tensor, tuple or list of Tensors."
    raise TypeError(message)
46
+
47
+
48
@constexpr
def _check_duplicate_grad_position(grad_position):
    """Raise ``ValueError`` if the sequence ``grad_position`` contains a repeated position.

    Equality is Python equality, so ``True`` and ``1`` count as the same position.
    """
    unique_positions = set(grad_position)
    if len(grad_position) != len(unique_positions):
        raise ValueError("There are duplicate positions in `grad_position`, please check it")
53
+
54
+
55
@constexpr
def _convert_grad_position_type(grad_position):
    """Check and convert the type and size of grad position index.

    Args:
        grad_position (Union[int, tuple[int]]): Index, or indices, of the inputs to differentiate.
            ``bool`` values are accepted and converted to ``int``.

    Returns:
        tuple[int], the validated positions. A single int is wrapped into a one-element tuple.

    Raises:
        TypeError: If `grad_position` is neither int nor tuple, or if a tuple element is not int.
        ValueError: If any position is negative, or if the tuple contains duplicate positions.
    """
    if isinstance(grad_position, tuple):
        positions = []
        for gp in grad_position:
            # Validate every element before the duplicate check. That check hashes the
            # elements and would otherwise fail with an unrelated error on unhashable entries.
            if not isinstance(gp, int):
                raise TypeError("For 'F.grad', the element in 'grad_position' must be int.")
            if gp < 0:
                raise ValueError("The element in grad_position must be >= 0.")
            # bool is a subclass of int; store it as a plain int.
            positions.append(int(gp) if isinstance(gp, bool) else gp)
        grad_position = tuple(positions)
        _check_duplicate_grad_position(grad_position)
        return grad_position
    if isinstance(grad_position, int):
        if grad_position < 0:
            raise ValueError("grad_position must be >= 0.")
        # Normalize bool the same way as the tuple branch does.
        if isinstance(grad_position, bool):
            grad_position = int(grad_position)
        return (grad_position,)
    raise TypeError("For 'F.grad', the 'grad_position' must be int or tuple.")
76
+
77
+
78
@constexpr
def _check_grad_position(grad_position, args_num):
    """Validate ``grad_position`` and ensure every position indexes one of ``args_num`` inputs.

    Args:
        grad_position (Union[int, tuple[int]]): Position(s) of the inputs to differentiate.
        args_num (int): Number of inputs passed to the differentiated function.

    Returns:
        tuple[int], the normalized positions produced by ``_convert_grad_position_type``.

    Raises:
        ValueError: If any position lies outside ``[0, args_num)``.
    """
    positions = _convert_grad_position_type(grad_position)
    out_of_range = any(pos < 0 or pos >= args_num for pos in positions)
    if out_of_range:
        raise ValueError("The element in grad_position must belong to [0, args_num).")
    return positions
86
+
87
+
88
@constexpr
def _get_grad_op(get_by_list, get_by_position, has_aux, get_value=False, return_ids=False):
    """Construct a ``_Grad`` operator whose flags are forwarded unchanged from the arguments.

    Args:
        get_by_list (bool): In this module, ``True`` when weights are differentiated.
        get_by_position (bool): In this module, ``True`` when selected inputs are differentiated.
        has_aux (bool): Forwarded as ``_Grad(has_aux=...)``.
        get_value (bool): ``True`` for ``value_and_grad``, ``False`` for ``grad`` in this module.
        return_ids (bool): Forwarded as ``_Grad(return_ids=...)``.

    Returns:
        The configured ``_Grad`` instance.
    """
    grad_op = _Grad(get_by_list=get_by_list,
                    get_by_position=get_by_position,
                    has_aux=has_aux,
                    get_value=get_value,
                    return_ids=return_ids)
    return grad_op
92
+
93
+
94
def grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
    """
    A wrapper function to generate the gradient function for the input function.

    As for gradient, three typical cases are included:

    1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
    2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
    3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.

    Args:
        fn (Union[Cell, Function]): Function to do GradOperation.
        grad_position (Union[NoneType, int, tuple[int]]): Index to specify which inputs to be differentiated.
            If int, get the gradient with respect to single input.
            If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
            If None, no derivative with respect to any input is computed, and in this case, `weights` is required.
            Default: ``0`` .
        weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
            Default: ``None`` .
        has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
            outputs will be returned directly. It means the `fn` must return more than one output in this case.
            Default: ``False`` .
        return_ids (bool): Whether to return tuples made of each gradient and the identifier it belongs to, i.e. the
            index of the differentiated input or the name of the differentiated network parameter.
            If ``True`` , the output gradients will be replaced by the tuples made by gradients and the index to specify
            which inputs to be differentiated or the name of parameters of the training network.
            Default: ``False`` .

    Returns:
        Function, the gradient function to calculate gradient for the input function or cell.
        For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
        outputs like `(gradient, out2)` and `out2` does not contribute to the differentiation, otherwise `gradient`.
        When return_ids is set to ``True`` , the format of the output will be the same as the output of grad when
        return_ids is set to ``False``, but every gradient in the output will be replaced by a tuple of position id or
        parameter name and its gradient.

    Raises:
        ValueError: If both `grad_position` and `weights` are None.
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops, nn, grad
        >>>
        >>> # Cell object to be differentiated
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x * y * z
        >>> x = Tensor([1, 2], mindspore.float32)
        >>> y = Tensor([-2, 3], mindspore.float32)
        >>> z = Tensor([0, 3], mindspore.float32)
        >>> net = Net()
        >>> output = grad(net, grad_position=(1, 2))(x, y, z)
        >>> print(output)
        (Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00]),
         Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00]))
        >>>
        >>> # Function object to be differentiated
        >>> def fn(x, y, z):
        ...     res = x * ops.exp(y) * ops.pow(z, 2)
        ...     return res, z
        >>> x = Tensor([3, 3], mindspore.float32)
        >>> y = Tensor([0, 0], mindspore.float32)
        >>> z = Tensor([5, 5], mindspore.float32)
        >>> gradient, aux = grad(fn, (1, 2), None, True)(x, y, z)
        >>> print(gradient)
        (Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
         Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
        >>> print(aux)
        (Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]),)
        >>>
        >>> # For given network to be differentiated with both inputs and weights, there are 4 cases.
        >>> net = nn.Dense(10, 1)
        >>> loss_fn = nn.MSELoss()
        >>> def forward(inputs, labels):
        ...     logits = net(inputs)
        ...     loss = loss_fn(logits, labels)
        ...     return loss, logits
        >>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
        >>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
        >>> weights = net.trainable_params()
        >>> # Case 1: gradient with respect to inputs.
        >>> # Aux value does not contribute to the gradient.
        >>> grad_fn = grad(forward, grad_position=(0, 1), weights=None, has_aux=True)
        >>> inputs_gradient, (aux_logits,) = grad_fn(inputs, labels)
        >>> print(len(inputs_gradient))
        2
        >>> print(aux_logits.shape)
        (16, 1)
        >>>
        >>> # Case 2: gradient with respect to weights.
        >>> grad_fn = grad(forward, grad_position=None, weights=weights, has_aux=True)
        >>> params_gradient, (aux_logits,) = grad_fn(inputs, labels)
        >>> print(len(weights), len(params_gradient))
        2 2
        >>> print(aux_logits.shape)
        (16, 1)
        >>>
        >>> # Case 3: gradient with respect to inputs and weights.
        >>> grad_fn = grad(forward, grad_position=0, weights=weights, has_aux=False)
        >>> inputs_gradient, params_gradient = grad_fn(inputs, labels)
        >>> print(len(weights), len(params_gradient))
        2 2
        >>> # Case 4: return the gradient with ids.
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> from mindspore import Tensor, ops
        >>> from mindspore import grad
        >>>
        >>> # Cell object to be differentiated
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x * y * z
        >>> x = Tensor([1, 2], mindspore.float32)
        >>> y = Tensor([-2, 3], mindspore.float32)
        >>> z = Tensor([0, 3], mindspore.float32)
        >>> net = Net()
        >>> output = grad(net, grad_position=(1, 2), return_ids = True)(x, y, z)
        >>> print(output)
        ((1, Tensor(shape=[2], dtype=Float32, value=[ 0.00000000e+00, 6.00000000e+00])),
         (2, Tensor(shape=[2], dtype=Float32, value=[-2.00000000e+00, 6.00000000e+00])))
    """
    if grad_position is None and weights is None:
        # Name the actual parameter (`weights`) in the message.
        raise ValueError("`grad_position` and `weights` can not be None at the same time.")

    # Differentiate with respect to weights only (get_by_list=True, get_by_position=False).
    if grad_position is None:
        return _get_grad_op(True, False, has_aux, False, return_ids)(fn, weights)

    # Validate and normalize an int / tuple position into a tuple of ints.
    grad_position = _convert_grad_position_type(grad_position)
    if weights is None:
        # Differentiate with respect to the selected inputs only.
        return _get_grad_op(False, True, has_aux, False, return_ids)(fn, None, grad_position)
    # Differentiate with respect to both the selected inputs and the weights.
    return _get_grad_op(True, True, has_aux, False, return_ids)(fn, weights, grad_position)
232
+
233
+
234
def value_and_grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False):
    """
    A wrapper function to generate the function to calculate forward output and gradient for the input function.

    As for gradient, three typical cases are included:

    1. gradient with respect to inputs. In this case, `grad_position` is not None while `weights` is None.
    2. gradient with respect to weights. In this case, `grad_position` is None while `weights` is not None.
    3. gradient with respect to inputs and weights. In this case, `grad_position` and `weights` are not None.

    Args:
        fn (Union[Cell, Function]): Function to do GradOperation.
        grad_position (Union[NoneType, int, tuple[int]]): Index to specify which inputs to be differentiated.
            If int, get the gradient with respect to single input.
            If tuple, get the gradients with respect to selected inputs. `grad_position` begins with 0.
            If None, no derivative with respect to any input is computed, and in this case, `weights` is required.
            Default: ``0`` .
        weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
            calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
            Default: ``None`` .
        has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
            outputs will be returned directly. It means the `fn` must return more than one output in this case.
            Default: ``False`` .
        return_ids (bool): Whether to return tuples made of each gradient and the identifier it belongs to, i.e. the
            index of the differentiated input or the name of the differentiated network parameter.
            If ``True`` , the output gradients will be replaced by the tuples made by gradients and the index to specify
            which inputs to be differentiated or the name of parameters of the training network.
            Default: ``False`` .

    Returns:
        Function, returns the gradient function to calculate forward output and gradient for the input function or cell.
        For example, as for `out1, out2 = fn(*args)` , gradient function will return outputs like
        `((out1, out2), gradient)` . When `has_aux` is set to ``True``, only `out1` contributes to the differentiation.

    Raises:
        ValueError: If both `grad_position` and `weights` are None.
        TypeError: If type of Args does not belong to required ones.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops, nn
        >>> from mindspore import value_and_grad
        >>>
        >>> # Cell object to be differentiated
        >>> class Net(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x * y * z
        >>> x = Tensor([1, 2], mindspore.float32)
        >>> y = Tensor([-2, 3], mindspore.float32)
        >>> z = Tensor([0, 3], mindspore.float32)
        >>> net = Net()
        >>> grad_fn = value_and_grad(net, grad_position=1)
        >>> output, inputs_gradient = grad_fn(x, y, z)
        >>> print(output)
        [-0. 18.]
        >>> print(inputs_gradient)
        [0. 6.]
        >>>
        >>> # Function object to be differentiated
        >>> def fn(x, y, z):
        ...     res = x * ops.exp(y) * ops.pow(z, 2)
        ...     return res, z
        >>> x = Tensor(np.array([3, 3]).astype(np.float32))
        >>> y = Tensor(np.array([0, 0]).astype(np.float32))
        >>> z = Tensor(np.array([5, 5]).astype(np.float32))
        >>> output, inputs_gradient = value_and_grad(fn, grad_position=(1, 2), weights=None, has_aux=True)(x, y, z)
        >>> print(output)
        (Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
         Tensor(shape=[2], dtype=Float32, value= [ 5.00000000e+00, 5.00000000e+00]))
        >>> print(inputs_gradient)
        (Tensor(shape=[2], dtype=Float32, value= [ 7.50000000e+01, 7.50000000e+01]),
         Tensor(shape=[2], dtype=Float32, value= [ 3.00000000e+01, 3.00000000e+01]))
        >>>
        >>> # For given network to be differentiated with both inputs and weights, there are 3 cases.
        >>> net = nn.Dense(10, 1)
        >>> loss_fn = nn.MSELoss()
        >>> def forward(inputs, labels):
        ...     logits = net(inputs)
        ...     loss = loss_fn(logits, labels)
        ...     return loss, logits
        >>> inputs = Tensor(np.random.randn(16, 10).astype(np.float32))
        >>> labels = Tensor(np.random.randn(16, 1).astype(np.float32))
        >>> weights = net.trainable_params()
        >>>
        >>> # Case 1: gradient with respect to inputs.
        >>> # Since has_aux is set True, only loss contributes to the gradient.
        >>> grad_fn = value_and_grad(forward, grad_position=0, weights=None, has_aux=True)
        >>> (loss, logits), inputs_gradient = grad_fn(inputs, labels)
        >>> print(logits.shape)
        (16, 1)
        >>> print(inputs.shape, inputs_gradient.shape)
        (16, 10) (16, 10)
        >>>
        >>> # Case 2: gradient with respect to weights.
        >>> # Since has_aux is set True, only loss contributes to the gradient.
        >>> grad_fn = value_and_grad(forward, grad_position=None, weights=weights, has_aux=True)
        >>> (loss, logits), params_gradient = grad_fn(inputs, labels)
        >>> print(logits.shape)
        (16, 1)
        >>> print(len(weights), len(params_gradient))
        2 2
        >>>
        >>> # Case 3: gradient with respect to inputs and weights.
        >>> # Since has_aux is set False, both loss and logits contribute to the gradient.
        >>> grad_fn = value_and_grad(forward, grad_position=0, weights=weights, has_aux=False)
        >>> (loss, logits), (inputs_gradient, params_gradient) = grad_fn(inputs, labels)
        >>> print(logits.shape)
        (16, 1)
        >>> print(inputs.shape, inputs_gradient.shape)
        (16, 10) (16, 10)
        >>> print(len(weights), len(params_gradient))
        2 2
    """
    if grad_position is None and weights is None:
        # Name the actual parameter (`weights`) in the message.
        raise ValueError("`grad_position` and `weights` can not be None at the same time.")

    # get_value=True (4th flag) makes the returned function also produce the forward output.
    # Differentiate with respect to weights only.
    if grad_position is None:
        return _get_grad_op(True, False, has_aux, True, return_ids)(fn, weights)

    # Validate and normalize an int / tuple position into a tuple of ints.
    grad_position = _convert_grad_position_type(grad_position)
    if weights is None:
        # Differentiate with respect to the selected inputs only.
        return _get_grad_op(False, True, has_aux, True, return_ids)(fn, None, grad_position)
    # Differentiate with respect to both the selected inputs and the weights.
    return _get_grad_op(True, True, has_aux, True, return_ids)(fn, weights, grad_position)
361
+
362
+
363
+ def get_grad(gradients, identifier):
364
+ """
365
+ When `return_ids` of :func:`mindspore.grad` or :func:`mindspore.grad` is set to ``True`` ,
366
+ use return value of `mindspore.grad`, or the second return value of `mindspore.grad` as gradients.
367
+ Then find the specific gradient from `gradients` according to `identifier` .
368
+
369
+ As for gradient, two typical cases are included:
370
+
371
+ 1. `identifier` is the position of the specific tensor to get gradient.
372
+ 2. `identifier` is a parameter of a network.
373
+
374
+ Args:
375
+ gradients (Union[tuple[int, Tensor], tuple[tuple, tuple]]): The return value of :func:`mindspore.grad`
376
+ when `return_ids` is set to True.
377
+ identifier (Union[int, Parameter]): The position number of a tensor, or a parameter that is used in
378
+ :func:`mindspore.grad`.
379
+
380
+ Returns:
381
+ The gradient of the tensor on the position or in the parameter that specified by the `identifier`.
382
+
383
+ Raises:
384
+ RuntimeError: If gradient is not found.
385
+ TypeError: If type of Args does not belong to required ones.
386
+
387
+ Supported Platforms:
388
+ ``Ascend`` ``GPU`` ``CPU``
389
+
390
+ Examples:
391
+ >>> import mindspore
392
+ >>> from mindspore import Tensor, nn
393
+ >>> from mindspore import grad, get_grad
394
+ >>>
395
+ >>> # Cell object to be differentiated
396
+ >>> class Net(nn.Cell):
397
+ ... def construct(self, x, y, z):
398
+ ... return x * y * z
399
+ >>> x = Tensor([1, 2], mindspore.float32)
400
+ >>> y = Tensor([-2, 3], mindspore.float32)
401
+ >>> z = Tensor([0, 3], mindspore.float32)
402
+ >>> net = Net()
403
+ >>> out_grad = grad(net, grad_position=(1, 2), return_ids=True)(x, y, z)
404
+ >>> output = get_grad(out_grad, 1)
405
+ >>> print(output)
406
+ [0. 6.]
407
+ """
408
+ return inner.GetGrad()(gradients, identifier)
409
+
410
+
411
+ def _trans_jet_inputs(primals_item, series_item):
412
+ """Trans inputs of jet"""
413
+ value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
414
+ if not dtype(primals_item) in value_type or dtype(primals_item) != dtype(series_item):
415
+ raise TypeError(f"For `F.jet`, the elements' types of primals and series must be the same and belong to "
416
+ f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
417
+ if dtype(primals_item) in [mstype.int32, mstype.int64]:
418
+ return cast(primals_item, mstype.float32), cast(series_item, mstype.float32)
419
+ return primals_item, series_item
420
+
421
+
422
+ def _check_jet_inputs(primals, series):
423
+ """Check inputs of jet"""
424
+ if not (isinstance(primals, Tensor) and isinstance(series, Tensor)) and \
425
+ not (isinstance(primals, tuple) and isinstance(series, tuple)):
426
+ raise TypeError(f"For 'F.jet', the 'primals' and `series` must be both Tensor or tuple.")
427
+ if isinstance(primals, Tensor):
428
+ if primals.shape == series.shape[1:]:
429
+ return _trans_jet_inputs(primals, series)
430
+ if primals.shape == series.shape:
431
+ return _trans_jet_inputs(primals, series.expand_dims(axis=0))
432
+ raise ValueError("In series, the shape of each element must be the same as the primals.")
433
+ if len(primals) != len(series):
434
+ raise ValueError("The lengths of primals and series must be the same.")
435
+ check_primals = []
436
+ check_series = []
437
+ for i, j in zip(primals, series):
438
+ trans_primals_item, trans_series_item = _trans_jet_inputs(i, j)
439
+ check_primals.append(trans_primals_item)
440
+ check_series.append(trans_series_item)
441
+ return check_primals, check_series
442
+
443
+
444
+ _taylor = _TaylorOperation()
445
+
446
+
447
+ def _preprocess_jet(x, y):
448
+ concat_op = P.Concat()
449
+ return concat_op((expand_dims(x, 0), y))
450
+
451
+
452
+ def jet(fn, primals, series):
453
+ """
454
+ This function is designed to calculate the higher order differentiation of given composite function. To figure out
455
+ first to `n`-th order differentiations, original inputs and first to `n`-th order derivative of original inputs
456
+ must be provided together. Generally, it is recommended to set the values of given first order derivative to 1,
457
+ while the other to 0, which is like the derivative of origin input with respect to itself.
458
+
459
+ Note:
460
+ If `primals` is Tensor of int type, it will be converted to Tensor of float type.
461
+
462
+ Args:
463
+ fn (Union[Cell, function]): Function to do TaylorOperation.
464
+ primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
465
+ series (Union[Tensor, tuple[Tensor]]): If tuple, the length and type of series should be the same as inputs.
466
+ For each Tensor, the length of first dimension `i` represents the `1` to `i+1`-th order of derivative of
467
+ output with respect to the inputs will be figured out.
468
+
469
+ Returns:
470
+ Tuple, tuple of out_primals and out_series.
471
+
472
+ - **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
473
+ - **out_series** (Union[Tensor, list[Tensor]]) - The `1` to `i+1`-th order of derivative of output with respect
474
+ to the inputs.
475
+
476
+ Raises:
477
+ TypeError: If `primals` is not a tensor or tuple of tensors.
478
+ TypeError: If type of `primals` is not the same as type of `series`.
479
+
480
+ Supported Platforms:
481
+ ``Ascend`` ``GPU`` ``CPU``
482
+
483
+ Examples:
484
+ >>> import numpy as np
485
+ >>> import mindspore.nn as nn
486
+ >>> import mindspore as ms
487
+ >>> from mindspore import ops
488
+ >>> from mindspore import Tensor
489
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
490
+ >>> class Net(nn.Cell):
491
+ ... def __init__(self):
492
+ ... super().__init__()
493
+ ... self.sin = ops.Sin()
494
+ ... self.exp = ops.Exp()
495
+ ... def construct(self, x):
496
+ ... out1 = self.sin(x)
497
+ ... out2 = self.exp(out1)
498
+ ... return out2
499
+ >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
500
+ >>> series = Tensor(np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]).astype(np.float32))
501
+ >>> net = Net()
502
+ >>> out_primals, out_series = ops.jet(net, primals, series)
503
+ >>> print(out_primals, out_series)
504
+ [[2.319777 2.4825778]
505
+ [1.1515628 0.4691642]] [[[ 1.2533808 -1.0331168 ]
506
+ [-1.1400385 -0.3066662 ]]
507
+ [[-1.2748207 -1.8274734 ]
508
+ [ 0.966121 0.55551505]]
509
+ [[-4.0515366 3.6724353 ]
510
+ [ 0.5053504 -0.52061415]]]
511
+ """
512
+ primals, series = _check_jet_inputs(primals, series)
513
+ derivative_fn = _taylor(fn)
514
+ if isinstance(primals, list) and len(primals) > 1:
515
+ inputs = map(_preprocess_jet, primals, series)
516
+ outputs = derivative_fn(*inputs)
517
+ else:
518
+ inputs = _preprocess_jet(primals, series)
519
+ outputs = derivative_fn(inputs)
520
+ if isinstance(outputs, tuple) and len(outputs) > 1:
521
+ out_primals = []
522
+ out_series = []
523
+ for element in outputs:
524
+ out_primals.append(element[0])
525
+ out_series.append(element[1:])
526
+ else:
527
+ out_primals = outputs[0]
528
+ out_series = outputs[1:]
529
+ return out_primals, out_series
530
+
531
+
532
+ def _trans_derivative_inputs(primals_item):
533
+ """Trans inputs of derivative"""
534
+ value_type = [mstype.int32, mstype.int64, mstype.float32, mstype.float64]
535
+ if not dtype(primals_item) in value_type:
536
+ raise TypeError(f"For `F.derivative`, the elements of primals must belong to "
537
+ f"`mstype.int32, mstype.int64, mstype.float32, mstype.float64`, but got other dtype.")
538
+ if dtype(primals_item) in [mstype.int32, mstype.int64]:
539
+ return cast(primals_item, mstype.float32)
540
+ return primals_item
541
+
542
+
543
+ @constexpr
544
+ def _check_derivative_order(order):
545
+ """check input order of derivative"""
546
+ if not isinstance(order, int):
547
+ raise TypeError(f"For `F.derivative`, the type of order must be int.")
548
+ if order < 1:
549
+ raise ValueError(f"For `F.derivative`, value of order should not be less than 1, but got {order}.")
550
+ return True
551
+
552
+
553
+ def _preprocess_derivate_order_one(x):
554
+ concat_op = P.Concat()
555
+ return concat_op((expand_dims(x, 0), ones((1,) + x.shape, dtype(x))))
556
+
557
+
558
+ def _preprocess_derivate_order_more(x, order):
559
+ concat_op = P.Concat()
560
+ return concat_op((x, zeros((order - 1,) + x[0].shape, dtype(x))))
561
+
562
+
563
+ def derivative(fn, primals, order):
564
+ """
565
+ This function is designed to calculate the higher order differentiation of given composite function. To figure out
566
+ `order`-th order differentiations, original inputs and order must be provided together. In particular, the value of
567
+ input first order derivative is set to 1, while the other to 0.
568
+
569
+ Note:
570
+ If `primals` is Tensor of int type, it will be converted to Tensor of float type.
571
+
572
+ Args:
573
+ fn (Union[Cell, function]): Function to do TaylorOperation.
574
+ primals (Union[Tensor, tuple[Tensor]]): The inputs to `fn`.
575
+ order (int): For each Tensor, the `order`-th order of derivative of output with respect to the inputs will be
576
+ figured out.
577
+
578
+ Returns:
579
+ Tuple, tuple of out_primals and out_series.
580
+
581
+ - **out_primals** (Union[Tensor, list[Tensor]]) - The output of `fn(primals)`.
582
+ - **out_series** (Union[Tensor, list[Tensor]]) - The `order`-th order of derivative of output with respect
583
+ to the inputs.
584
+
585
+ Raises:
586
+ TypeError: If `primals` is not a tensor or tuple of tensors.
587
+ TypeError: If `order` is not int.
588
+ ValueError: If `order` is less than 1.
589
+
590
+ Supported Platforms:
591
+ ``Ascend`` ``GPU`` ``CPU``
592
+
593
+ Examples:
594
+ >>> import numpy as np
595
+ >>> import mindspore as ms
596
+ >>> import mindspore.nn as nn
597
+ >>> from mindspore import ops
598
+ >>> from mindspore import Tensor
599
+ >>> ms.set_context(mode=ms.GRAPH_MODE)
600
+ >>> class Net(nn.Cell):
601
+ ... def __init__(self):
602
+ ... super().__init__()
603
+ ... self.sin = ops.Sin()
604
+ ... self.exp = ops.Exp()
605
+ ... def construct(self, x):
606
+ ... out1 = self.sin(x)
607
+ ... out2 = self.exp(out1)
608
+ ... return out2
609
+ >>> primals = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
610
+ >>> order = 3
611
+ >>> net = Net()
612
+ >>> out_primals, out_series = ops.derivative(net, primals, order)
613
+ >>> print(out_primals, out_series)
614
+ [[2.319777 2.4825778]
615
+ [1.1515628 0.4691642]] [[-4.0515366 3.6724353 ]
616
+ [ 0.5053504 -0.52061415]]
617
+ """
618
+ derivative_fn = _taylor(fn)
619
+ concat_op = P.Concat()
620
+ series_one = 1
621
+ _check_derivative_order(order)
622
+ if isinstance(primals, tuple):
623
+ trans_primals = map(_trans_derivative_inputs, primals)
624
+ inputs = map(_preprocess_derivate_order_one, trans_primals)
625
+ if order > 1:
626
+ processed_inputs = []
627
+ for element in inputs:
628
+ processed_inputs.append(_preprocess_derivate_order_more(element, order))
629
+ outputs = derivative_fn(*processed_inputs)
630
+ else:
631
+ outputs = derivative_fn(*inputs)
632
+ else:
633
+ primals = _trans_derivative_inputs(primals)
634
+ series = zeros((order,) + primals.shape, dtype(primals))
635
+ series[0] = series_one
636
+ inputs = concat_op((expand_dims(primals, 0), series))
637
+ outputs = derivative_fn(inputs)
638
+ if isinstance(outputs, tuple) and len(outputs) > 1:
639
+ out_primals = []
640
+ out_series = []
641
+ for element in outputs:
642
+ out_primals.append(element[0])
643
+ out_series.append(element[-1])
644
+ else:
645
+ out_primals = outputs[0]
646
+ out_series = outputs[-1]
647
+ return out_primals, out_series
648
+
649
+
650
+ _grad_single = GradOperation(sens_param=True)
651
+ _grad_all = GradOperation(sens_param=True, get_all=True)
652
+
653
+
654
+ @constexpr
655
+ def _check_jvp_input_v_len(inputs_len, v_len):
656
+ if inputs_len != v_len:
657
+ raise ValueError(f'v has invalid length: should be {inputs_len}, but got {v_len}')
658
+
659
+
660
+ def jvp(fn, inputs, v, has_aux=False):
661
+ """
662
+ Compute the jacobian-vector-product of the given network. The calculation procedure of JVP can be found in
663
+ `forward-mode differentiation
664
+ <https://www.mindspore.cn/docs/en/master/design/programming_paradigm.html#forward-mode-ad>`_.
665
+
666
+ Args:
667
+ fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
668
+ Tensors.
669
+ inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
670
+ v (Union[Tensor, tuple[Tensor], list[Tensor]]): The vector in jacobian-vector-product. The shape and type of `v`
671
+ should be the same as `inputs` .
672
+ has_aux (bool): If ``True`` , only the first output of `fn` contributes the gradient of `fn`, while the other
673
+ outputs will be returned straightly. It means the `fn` must return more than one outputs in this case.
674
+ Default: ``False`` .
675
+
676
+ Returns:
677
+ - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)` . Specially, when `has_aux` is set
678
+ ``True`` , `netout` is the first output of `fn(inputs)` .
679
+ - **jvp** (Union[Tensor, tuple[Tensor]]) - The result of jacobian-vector-product.
680
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is ``True`` , `aux_value` will be
681
+ returned. It means the second to last outputs of `fn(inputs)` . Specially, `aux_value` does not contribute to
682
+ gradient.
683
+
684
+ Raises:
685
+ TypeError: `inputs` or `v` does not belong to required types.
686
+
687
+ Supported Platforms:
688
+ ``Ascend`` ``GPU`` ``CPU``
689
+
690
+ Examples:
691
+ >>> import numpy as np
692
+ >>> from mindspore import jvp
693
+ >>> from mindspore import Tensor
694
+ >>> import mindspore.nn as nn
695
+ >>> class Net(nn.Cell):
696
+ ... def construct(self, x, y):
697
+ ... return x**3 + y
698
+ >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
699
+ >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
700
+ >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
701
+ >>> output = jvp(Net(), (x, y), (v, v))
702
+ >>> print(output[0])
703
+ [[ 2. 10.]
704
+ [30. 68.]]
705
+ >>> print(output[1])
706
+ [[ 4. 13.]
707
+ [28. 49.]]
708
+ >>>
709
+ >>> def fn(x, y):
710
+ ... return x ** 3 + y, y
711
+ >>> output, jvp_out, aux = jvp(fn, (x, y), (v, v), has_aux=True)
712
+ >>> print(output)
713
+ [[ 2. 10.]
714
+ [30. 68.]]
715
+ >>> print(jvp_out)
716
+ [[ 4. 13.]
717
+ [28. 49.]]
718
+ >>> print(aux)
719
+ [[ 1. 2.]
720
+ [3. 4.]]
721
+ """
722
+ _check_has_aux_type(has_aux)
723
+
724
+ def aux_fn(*args):
725
+ outputs = fn(*args)
726
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
727
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
728
+ res = outputs[0]
729
+ return res
730
+
731
+ def grad_single(u, first_grad_single_value):
732
+ if has_aux:
733
+ return _grad_single(aux_fn)(*first_grad_single_value, u)
734
+ return _grad_single(fn)(*first_grad_single_value, u)
735
+
736
+ def grad_all(u, first_grad):
737
+ if has_aux:
738
+ return _grad_all(aux_fn)(*first_grad, u)
739
+ return _grad_all(fn)(*first_grad, u)
740
+
741
+ def _wrap_container_inner(*arg):
742
+ jvp_inputs = arg[1:]
743
+ vectors = arg[0]
744
+ if has_aux:
745
+ outputs = aux_fn(*jvp_inputs)
746
+ else:
747
+ outputs = fn(*jvp_inputs)
748
+ if isinstance(outputs, tuple):
749
+ u = ()
750
+ for item in outputs:
751
+ u = u + (mutable(oneslike(item)),)
752
+ else:
753
+ u = mutable(oneslike(outputs))
754
+ if len(jvp_inputs) == 1:
755
+ second_grad_net = _grad_single(grad_single)
756
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
757
+ else:
758
+ second_grad_net = _grad_single(grad_all)
759
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
760
+ if has_aux:
761
+ res = fn(*jvp_inputs)
762
+ if len(res) == 2:
763
+ return res[0], gradient_outputs, res[1]
764
+ return res[0], gradient_outputs, res[1:]
765
+ return outputs, gradient_outputs
766
+
767
+ if has_aux:
768
+ @jit(hash_args=aux_fn)
769
+ def _wrap_container(*arg):
770
+ return _wrap_container_inner(*arg)
771
+ else:
772
+ @jit(hash_args=fn)
773
+ def _wrap_container(*arg):
774
+ return _wrap_container_inner(*arg)
775
+
776
+ if not isinstance(inputs, (Tensor, tuple, list)) or not isinstance(v, (Tensor, tuple, list)):
777
+ _raise_type_error()
778
+
779
+ inputs_len = 1
780
+ v_len = 1
781
+ if isinstance(inputs, (tuple, list)):
782
+ inputs_len = len(inputs)
783
+ if isinstance(v, (tuple, list)):
784
+ v_len = len(v)
785
+ _check_jvp_input_v_len(inputs_len, v_len)
786
+
787
+ if isinstance(v, list):
788
+ v = tuple(v)
789
+ if isinstance(inputs, (tuple, list)):
790
+ return _wrap_container(v, *inputs)
791
+ return _wrap_container(v, inputs)
792
+
793
+
794
+ def linearize(fn, inputs):
795
+ """
796
+ Produces a linear approximation to fun using jvp() and partial eval.
797
+ This function is mainly useful if you want to apply jvp multiple times.
798
+
799
+ Args:
800
+ fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single tensor or tuple of
801
+ Tensors.
802
+ inputs (Union[Tensor, Tuple or List of Tensors]): The inputs to `fn`.
803
+
804
+ Returns:
805
+ Tuple, tuple of output and jvp_fn.
806
+
807
+ - **netout** (Tensor or Tuple of Tensors) - The output of "fn(inputs)".
808
+ - **jvp_fn** (Function) - The function that evaluates the Jacobian-vector product.
809
+
810
+ Raises:
811
+ TypeError: If the input is not a tensor or tuple or list of tensors.
812
+
813
+ Supported Platforms:
814
+ ``Ascend`` ``GPU`` ``CPU``
815
+
816
+ Examples:
817
+ >>> import numpy as np
818
+ >>> from mindspore import Tensor, Parameter, ops
819
+ >>> from mindspore import nn
820
+ >>> from mindspore.ops.functional import linearize
821
+
822
+ >>> class Net(nn.Cell):
823
+ ... def __init__(self):
824
+ ... super(Net, self).__init__()
825
+ ... self.matmul = ops.MatMul()
826
+ ... def construct(self, x, y):
827
+ ... out = self.matmul(x, y)
828
+ ... return out
829
+ >>> x = Tensor(np.array([[1, 2, 3], [3, 4, 5]]).astype(np.float32))
830
+ >>> y = Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32))
831
+ >>> v = (Tensor(np.array([[1, 1, 1], [1, 1, 1]]).astype(np.float32)),
832
+ ... Tensor(np.array([[1, 1], [1, 1], [0, 0]]).astype(np.float32)))
833
+ >>> output, jvp_fn = linearize(Net(), (x, y))
834
+ >>> print(output)
835
+ [[22. 28.]
836
+ [40. 52.]]
837
+ >>> jvp = jvp_fn(v)
838
+ >>> print(jvp)
839
+ [[12. 15.]
840
+ [16. 19.]]
841
+ """
842
+ linearize_inner = _LinearizeInner()
843
+
844
+ @jit(hash_args=fn)
845
+ def _wrap_container(*arg):
846
+ args = arg[1:-1]
847
+ vectors = arg[-1]
848
+ output = arg[0]
849
+ if isinstance(vectors, list):
850
+ vectors = tuple(vectors)
851
+ return linearize_inner(fn, vectors, output, args)
852
+
853
+ if not isinstance(inputs, (Tensor, tuple, list)):
854
+ _raise_type_error()
855
+ if isinstance(inputs, Tensor):
856
+ inputs = (inputs,)
857
+ output = fn(*inputs)
858
+ return output, partial(_wrap_container, output, *inputs)
859
+
860
+
861
+ def _check_tensor(inputs):
862
+ if not isinstance(inputs, (Tensor, tuple)):
863
+ raise TypeError("The inputs type must be Tensor.")
864
+ if isinstance(inputs, tuple):
865
+ for item in inputs:
866
+ if not isinstance(item, (Tensor, tuple, list)):
867
+ raise TypeError("The inputs type must be Tensor.")
868
+ return True
869
+
870
+
871
+ _vjp_grad_op = _Grad(get_all=True, sens_param=True, merge_forward=True)
872
+ _vjp_grad_op_with_weight = _Grad(get_all=True, get_by_list=True, sens_param=True, merge_forward=True)
873
+
874
+
875
+ def vjp(fn, *inputs, weights=None, has_aux=False):
876
+ """
877
+ Compute the vector-jacobian-product of the given network. `vjp` matches
878
+ `reverse-mode differentiation
879
+ <https://www.mindspore.cn/docs/en/master/design/programming_paradigm.html#reverse-mode-ad>`_.
880
+
881
+ Args:
882
+ fn (Union[Function, Cell]): The function or net that takes Tensor inputs and returns single Tensor or tuple of
883
+ Tensors.
884
+ inputs (Union[Tensor, tuple[Tensor], list[Tensor]]): The inputs to `fn` .
885
+ weights (Union[ParameterTuple, Parameter, list[Parameter]]): The parameters of the training network that need to
886
+ calculate the gradient. `weights` can be got through `weights = net.trainable_params()` .
887
+ Default: ``None`` .
888
+ has_aux (bool): If True, only the first output of `fn` contributes the gradient of `fn`, while the other outputs
889
+ will be returned straightly. It means the `fn` must return more than one outputs in this case.
890
+ Default: ``False``.
891
+
892
+ Returns:
893
+ Forward outputs and function to calculate vjp.
894
+
895
+ - **net_output** (Union[Tensor, tuple[Tensor]]) - The output of `fn(inputs)`.
896
+ Specially, when `has_aux` is set to
897
+ ``True``, `net_output` is the first output of `fn(inputs)`.
898
+ - **vjp_fn** (Function) - To calculate vector-jacobian-product. Its inputs are the vectors whose shape and
899
+ type should be the same as `net_output` .
900
+ - **aux_value** (Union[Tensor, tuple[Tensor]], optional) - When `has_aux` is True, `aux_value` will be returned.
901
+ It means the second to last outputs of `fn(inputs)`. Specially, `aux_value` does not contribute to gradient.
902
+
903
+ Raises:
904
+ TypeError: `inputs` or `v` does not belong to required types.
905
+
906
+ Supported Platforms:
907
+ ``Ascend`` ``GPU`` ``CPU``
908
+
909
+ Examples:
910
+ >>> import numpy as np
911
+ >>> import mindspore.nn as nn
912
+ >>> from mindspore import vjp
913
+ >>> from mindspore import Tensor
914
+ >>> class Net(nn.Cell):
915
+ ... def construct(self, x, y):
916
+ ... return x**3 + y
917
+ >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
918
+ >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
919
+ >>> v = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
920
+ >>> outputs, vjp_fn = vjp(Net(), x, y)
921
+ >>> print(outputs)
922
+ [[ 2. 10.]
923
+ [30. 68.]]
924
+ >>> gradient = vjp_fn(v)
925
+ >>> print(gradient)
926
+ (Tensor(shape=[2, 2], dtype=Float32, value=
927
+ [[ 3.00000000e+00, 1.20000000e+01],
928
+ [ 2.70000000e+01, 4.80000000e+01]]), Tensor(shape=[2, 2], dtype=Float32, value=
929
+ [[ 1.00000000e+00, 1.00000000e+00],
930
+ [ 1.00000000e+00, 1.00000000e+00]]))
931
+ >>> def fn(x, y):
932
+ ... return 2 * x + y, y ** 3
933
+ >>> outputs, vjp_fn, aux = vjp(fn, x, y, has_aux=True)
934
+ >>> gradient = vjp_fn(v)
935
+ >>> print(outputs)
936
+ [[ 3. 6.]
937
+ [ 9. 12.]]
938
+ >>> print(aux)
939
+ [[ 1. 8.]
940
+ [27. 64.]]
941
+ >>> print(gradient)
942
+ (Tensor(shape=[2, 2], dtype=Float32, value=
943
+ [[ 2.00000000e+00, 2.00000000e+00],
944
+ [ 2.00000000e+00, 2.00000000e+00]]), Tensor(shape=[2, 2], dtype=Float32, value=
945
+ [[ 1.00000000e+00, 1.00000000e+00],
946
+ [ 1.00000000e+00, 1.00000000e+00]]))
947
+ """
948
+ _check_tensor(inputs)
949
+ _check_has_aux_type(has_aux)
950
+
951
+ def aux_fn(*args):
952
+ outputs = fn(*args)
953
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
954
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
955
+ res = outputs[0]
956
+ return res
957
+
958
+ def wrap_container(*v):
959
+ _check_tensor(v)
960
+ if has_aux:
961
+ fn_ = aux_fn
962
+ else:
963
+ fn_ = fn
964
+ sens = v
965
+ if len(v) == 1:
966
+ sens = v[0]
967
+ if weights is None:
968
+ return _vjp_grad_op(fn_)(*inputs, sens)
969
+ return _vjp_grad_op_with_weight(fn_, weights)(*inputs, sens)
970
+
971
+ res = fn(*inputs)
972
+ if has_aux:
973
+ if len(res) == 2:
974
+ return res[0], wrap_container, res[1]
975
+ return res[0], wrap_container, res[1:]
976
+ return res, wrap_container
977
+
978
+
979
+ @_primexpr
980
+ def _jac_generate_target_dimension(x):
981
+ """For given length = len(x), this method generates target dimension tuple (1, 2, 3,..., length, 0)."""
982
+ dim = ()
983
+ for index in range(len(x[1:])):
984
+ dim += (index + 1,)
985
+ target_dimension = dim + (0,)
986
+ return target_dimension
987
+
988
+
989
+ def _jacfwd_trans_item(item, inputs_shape, grad_position):
990
+ """transfer origin item to derivative of each output with respect to each input."""
991
+ output_wrt_input_all = ()
992
+ for i in grad_position:
993
+ origin_output_wrt_input = item[inputs_shape[i][1]:inputs_shape[i + 1][1]]
994
+ target_dimension = _jac_generate_target_dimension(origin_output_wrt_input.shape)
995
+ temp = transpose(origin_output_wrt_input, target_dimension)
996
+ output_wrt_input = reshape(temp, temp.shape[:-1] + inputs_shape[i + 1][0])
997
+ output_wrt_input_all += (output_wrt_input,)
998
+ return output_wrt_input_all
999
+
1000
+
1001
+ def _jac_postprocess(x, shape, grad_position, mode):
1002
+ """reformat jacobian."""
1003
+
1004
+ if mode == 'forward':
1005
+ func = _jacfwd_trans_item
1006
+ args = (shape, grad_position)
1007
+ else:
1008
+ func = _jacrev_trans_item
1009
+ args = (shape,)
1010
+
1011
+ if isinstance(x, tuple):
1012
+ jacobian = ()
1013
+ for item in x:
1014
+ jacobian += func(item, *args)
1015
+ res = jacobian
1016
+ else:
1017
+ res = func(x, *args)
1018
+ if len(res) == 1:
1019
+ return res[0]
1020
+ input_num = len(grad_position)
1021
+ if len(res) % input_num != 0:
1022
+ raise ValueError("The numbers of inputs and outputs do not match.")
1023
+ output_num = len(res) // input_num
1024
+ if input_num == 1 or output_num == 1:
1025
+ return res
1026
+ jac = ()
1027
+ for i in range(output_num):
1028
+ input_grad = ()
1029
+ for j in range(input_num):
1030
+ input_grad += (res[i * input_num + j],) if mode == 'forward' else (res[j * output_num + i],)
1031
+ jac += (input_grad,)
1032
+ return jac
1033
+
1034
+
1035
+ def _jacfwd_postprocess(x, inputs_shape, grad_position):
1036
+ """reformat forward-computed Jacobian."""
1037
+ return _jac_postprocess(x, inputs_shape, grad_position, 'forward')
1038
+
1039
+
1040
+ def _jacfwd_construct_v(inputs, grad_position):
1041
+ """
1042
+ For input (x1, x2), x1.shape = (a, b), x2.shape = (c, d), this method generates corresponding v (v1, v2),
1043
+ v1.shape = (N, a, b), v2.shape = (N, c, d), while N = a*b + c*d.
1044
+ """
1045
+ v = ()
1046
+ primals = ()
1047
+ inputs_shape = (((), 0),)
1048
+ num = 0
1049
+ items_num = ()
1050
+ cum_num = (0,)
1051
+ for item in inputs:
1052
+ num += size(item)
1053
+ inputs_shape += ((item.shape, num),)
1054
+ items_num += (size(item),)
1055
+ cum_num += (num,)
1056
+ for i, element in enumerate(inputs):
1057
+ item_size = items_num[i]
1058
+ if i in grad_position:
1059
+ temp2 = Tensor(np.eye(num, item_size, -cum_num[i], np.float32))
1060
+ else:
1061
+ temp2 = zeros((num, item_size), mstype.float32)
1062
+ input_v = reshape(temp2, (num,) + element.shape)
1063
+ primal = broadcast_to(element, (num,) + element.shape)
1064
+ v += (input_v,)
1065
+ primals += (primal,)
1066
+ if len(inputs) == 1:
1067
+ return primals, v[0], inputs_shape
1068
+ return primals, v, inputs_shape
1069
+
1070
+
1071
+ _vmap = _Vmap()
1072
+
1073
+
1074
+ def jacfwd(fn, grad_position=0, has_aux=False):
1075
+ """
1076
+ Compute Jacobian via forward mode, corresponding to
1077
+ `forward-mode differentiation
1078
+ <https://www.mindspore.cn/docs/en/master/design/programming_paradigm.html#forward-mode-ad>`_.
1079
+ When number of outputs is much greater than that of inputs, it's better to calculate Jacobian via forward mode than
1080
+ reverse mode to get better performance.
1081
+
1082
+ Args:
1083
+ fn (Union[Cell, Function]): Function to do GradOperation.
1084
+ grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
1085
+ If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: ``0`` .
1086
+ has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
1087
+ while the other outputs will be returned straightly. It means the `fn` must return more than one
1088
+ outputs in this case. Default: ``False`` .
1089
+
1090
+ Returns:
1091
+ Function, returns the Jacobian function for the input function or cell.
1092
+ For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
1093
+ outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
1094
+
1095
+ Raises:
1096
+ TypeError: `grad_position` or `has_aux` does not belong to required types.
1097
+
1098
+ Supported Platforms:
1099
+ ``Ascend`` ``GPU`` ``CPU``
1100
+
1101
+ Examples:
1102
+ >>> import numpy as np
1103
+ >>> import mindspore.nn as nn
1104
+ >>> from mindspore import jacfwd
1105
+ >>> from mindspore import Tensor
1106
+ >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
1107
+ ... def construct(self, x, y, z):
1108
+ ... return x ** 2 + y ** 2 + z ** 2, x * y * z
1109
+ >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1110
+ >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
1111
+ >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
1112
+ >>> net = MultipleInputsMultipleOutputsNet()
1113
+ >>> jac, aux = jacfwd(net, grad_position=0, has_aux=True)(x, y, z)
1114
+ >>> print(jac)
1115
+ [[[[ 2. 0.]
1116
+ [ 0. 0.]]
1117
+ [[ 0. 4.]
1118
+ [ 0. 0.]]]
1119
+ [[[ 0. 0.]
1120
+ [ 6. 0.]]
1121
+ [[ 0. 0.]
1122
+ [ 0. 8.]]]]
1123
+ >>> print(aux)
1124
+ [[ 1. 4.]
1125
+ [ 9. 16.]]
1126
+ """
1127
+ _check_has_aux_type(has_aux)
1128
+
1129
+ def aux_fn(*args):
1130
+ outputs = fn(*args)
1131
+ if not isinstance(outputs, tuple) or len(outputs) < 2:
1132
+ raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
1133
+ res = outputs[0]
1134
+ return res
1135
+
1136
+ def grad_single(u, first_grad_single_value):
1137
+ if has_aux:
1138
+ return _grad_single(aux_fn)(*first_grad_single_value, u)
1139
+ return _grad_single(fn)(*first_grad_single_value, u)
1140
+
1141
+ def grad_all(u, first_grad):
1142
+ if has_aux:
1143
+ return _grad_all(aux_fn)(*first_grad, u)
1144
+ return _grad_all(fn)(*first_grad, u)
1145
+
1146
+ @jit
1147
+ def wrapped(*args):
1148
+ checked_grad_position = _check_grad_position(grad_position, len(args))
1149
+ primals, v, inputs_shape = _jacfwd_construct_v(args, checked_grad_position)
1150
+
1151
+ def inner_fn(jvp_inputs, vectors):
1152
+ outputs = fn(*jvp_inputs)
1153
+ if isinstance(outputs, tuple):
1154
+ u = ()
1155
+ for item in outputs:
1156
+ u = u + (mutable(oneslike(item)),)
1157
+ else:
1158
+ u = mutable(oneslike(outputs))
1159
+ if len(jvp_inputs) == 1:
1160
+ second_grad_net = _grad_single(grad_single)
1161
+ else:
1162
+ second_grad_net = _grad_single(grad_all)
1163
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
1164
+ return gradient_outputs
1165
+
1166
+ def inner_aux_fn(jvp_inputs, vectors):
1167
+ outputs = aux_fn(*jvp_inputs)
1168
+ u = mutable(oneslike(outputs))
1169
+ if len(jvp_inputs) == 1:
1170
+ second_grad_net = _grad_single(grad_single)
1171
+ else:
1172
+ second_grad_net = _grad_single(grad_all)
1173
+ gradient_outputs = second_grad_net(u, jvp_inputs, vectors)
1174
+ return gradient_outputs
1175
+
1176
+ if has_aux:
1177
+ res = _vmap(inner_aux_fn)(primals, v)
1178
+ jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
1179
+ forward_outputs = fn(*args)
1180
+ if len(forward_outputs) == 2:
1181
+ return jac_res, forward_outputs[1]
1182
+ return jac_res, forward_outputs[1:]
1183
+ res = _vmap(inner_fn)(primals, v)
1184
+ jac_res = _jacfwd_postprocess(res, inputs_shape, checked_grad_position)
1185
+ return jac_res
1186
+
1187
+ return wrapped
1188
+
1189
+
1190
def _jacrev_trans_item(item, outputs_shape):
    """
    Split one reverse-mode Jacobian slab into a per-output tuple of derivatives.

    Args:
        item (Tensor): Raw Jacobian for one input. Axis 0 is sliced using the cumulative bounds stored in
            `outputs_shape`.
        outputs_shape (tuple): Sequence of ``(shape, cumulative_count)`` pairs. Presumably it has the layout built
            by ``_jacrev_construct_v``: ``(((), 0), (shape_1, n_1), (shape_1 + ..., n_1 + n_2), ...)``. TODO confirm
            against the caller.

    Returns:
        tuple: One Tensor per output, reshaped to ``output_shape + trailing_shape``.
    """
    per_output = ()
    for idx in range(1, len(outputs_shape)):
        begin = outputs_shape[idx - 1][1]
        end = outputs_shape[idx][1]
        block = item[begin:end]
        # NOTE(review): the permuted tensor is used only for its shape. The reshape below is applied to the
        # un-permuted `block`. This assumes `_jac_generate_target_dimension` moves axis 0 last; confirm.
        permuted = transpose(block, _jac_generate_target_dimension(block.shape))
        per_output += (reshape(block, outputs_shape[idx][0] + permuted.shape[:-1]),)
    return per_output
1201
+
1202
+
1203
def _jacrev_postprocess(x, outputs_shape, grad_position):
    """
    Reformat a reverse-computed Jacobian into its final layout.

    Thin wrapper that forwards every argument to ``_jac_postprocess`` with the mode fixed to ``'reverse'``.
    """
    mode = 'reverse'
    return _jac_postprocess(x, outputs_shape, grad_position, mode)
1206
+
1207
+
1208
def _jacrev_construct_v(inputs, outputs, has_aux=False):
    """
    Build the batched primals and the one-hot cotangent basis used by `jacrev`.

    For outputs (y1, y2), y1.shape = (a, b), y2.shape = (c, d), this method generates corresponding v (v1, v2),
    v1.shape = (N, a, b), v2.shape = (N, c, d), while N = a*b + c*d.
    Row ``k`` of ``v_i`` is one-hot at flattened position ``k - offset_i``, and all-zero when that position lies
    outside output ``i``. Offset ``i`` is the total element count of the outputs before it. Taken together, row
    ``k`` therefore selects the ``k``-th scalar of the concatenated, flattened outputs. `jacrev` passes the results
    to ``_vmap``.

    Args:
        inputs (tuple): Function inputs. Each one is broadcast to ``(N,) + input.shape``.
        outputs (Union[Tensor, tuple[Tensor]]): Forward outputs of the function.
        has_aux (bool): If True, only ``outputs[0]`` is considered.

    Returns:
        tuple: ``(primals, v, outputs_shape)``.
        ``v`` is a single Tensor when exactly one output is considered or `has_aux` is True, and a tuple otherwise.
        ``outputs_shape`` is ``(((), 0), (shape_1, n_1), (shape_2, n_1 + n_2), ...)``, holding cumulative counts.
    """
    if isinstance(outputs, Tensor):
        outputs = (outputs,)
    if has_aux:
        # Only the first output is differentiated when auxiliary outputs are present.
        outputs = (outputs[0],)

    total = 0
    sizes = ()
    offsets = (0,)
    outputs_shape = (((), 0),)
    for out in outputs:
        out_size = size(out)
        total += out_size
        sizes += (out_size,)
        offsets += (total,)
        outputs_shape += ((out.shape, total),)

    primals = ()
    for inp in inputs:
        primals += (broadcast_to(inp, (total,) + inp.shape),)

    v = ()
    for idx, out in enumerate(outputs):
        # NOTE(review): cotangents are always float32, whatever the output dtype.
        basis = Tensor(np.eye(total, sizes[idx], -offsets[idx], np.float32))
        v += (reshape(basis, (total,) + out.shape),)

    if has_aux or len(outputs) == 1:
        return primals, v[0], outputs_shape
    return primals, v, outputs_shape
1240
+
1241
+
1242
# Shared gradient operator used by `jacrev`. Call sites pass a grad position, so inputs are selected by position,
# and they pass the cotangent explicitly as an extra trailing argument (sens_param=True).
_grad = _Grad(sens_param=True, has_aux=False, get_by_position=True)
1243
+
1244
+
1245
def jacrev(fn, grad_position=0, has_aux=False):
    """
    Compute Jacobian via reverse mode, corresponding to
    `reverse-mode differentiation
    <https://www.mindspore.cn/docs/en/master/design/programming_paradigm.html#reverse-mode-ad>`_.
    When number of inputs is much greater than that of outputs, it's better to calculate Jacobian via reverse mode than
    forward mode to get better performance.

    Args:
        fn (Union[Cell, Function]): Function to do GradOperation.
        grad_position (Union[int, tuple[int]], optional): If int, get the gradient with respect to single input.
            If tuple, get the gradients with respect to selected inputs. 'grad_position' begins with 0. Default: ``0`` .
        has_aux (bool, optional): If ``True`` , only the first output of `fn` contributes the gradient of `fn`,
            while the other outputs will be returned straightly. It means the `fn` must return more than
            one output in this case. Default: ``False`` .

    Returns:
        Function, returns the Jacobian function for the input function or cell.
        For example, as for `out1, out2 = fn(*args)`, when `has_aux` is set ``True`` , gradient function will return
        outputs like `(Jacobian, out2)` and `out2` does not contribute to the differentiation, otherwise `Jacobian` .
        When `fn` returns more than two outputs and `has_aux` is ``True`` , the auxiliary part is the tuple of every
        output after the first.

    Raises:
        TypeError: `grad_position` or `has_aux` does not belong to required types.
        ValueError: `has_aux` is ``True`` and `fn` does not return a tuple of at least two outputs. This is raised
            when the returned Jacobian function is called.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore.nn as nn
        >>> from mindspore import jacrev
        >>> from mindspore import Tensor
        >>> class MultipleInputsMultipleOutputsNet(nn.Cell):
        ...     def construct(self, x, y, z):
        ...         return x ** 2 + y ** 2 + z ** 2, x * y * z
        >>> x = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> y = Tensor(np.array([[1, 2], [3, 4]]).astype(np.float32))
        >>> z = Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))
        >>> net = MultipleInputsMultipleOutputsNet()
        >>> jac, aux = jacrev(net, grad_position=0, has_aux=True)(x, y, z)
        >>> print(jac)
        [[[[ 2.  0.]
           [ 0.  0.]]
          [[ 0.  4.]
           [ 0.  0.]]]
         [[[ 0.  0.]
           [ 6.  0.]]
          [[ 0.  0.]
           [ 0.  8.]]]]
        >>> print(aux)
        [[ 1.  4.]
         [ 9. 16.]]
    """
    _check_has_aux_type(has_aux)

    def aux_fn(*args):
        outputs = fn(*args)
        if not isinstance(outputs, tuple) or len(outputs) < 2:
            raise ValueError("When 'has_aux' is True, origin 'fn' requires more than one outputs.")
        res = outputs[0]
        return res

    # The function actually differentiated. With auxiliary outputs, only the first output contributes.
    # The choice is made here, outside `jit`, so the compiled graph sees a single plain function.
    diff_fn = aux_fn if has_aux else fn

    @jit
    def wrapped(*args):
        checked_grad_position = _check_grad_position(grad_position, len(args))
        # A single forward pass has two uses. Its structure builds the cotangent basis, and when `has_aux` is True
        # its trailing entries are returned as the auxiliary data. The original code called `fn` a second time here.
        outputs = fn(*args)
        primals, v, outputs_shape = _jacrev_construct_v(args, outputs, has_aux)

        def inner_fn(vjp_inputs, vectors):
            gradient_outputs = _grad(diff_fn, None, checked_grad_position)(*vjp_inputs, vectors)
            return gradient_outputs

        res = _vmap(inner_fn)(primals, v)
        jac_res = _jacrev_postprocess(res, outputs_shape, checked_grad_position)
        if not has_aux:
            return jac_res
        if len(outputs) == 2:
            return jac_res, outputs[1]
        return jac_res, outputs[1:]

    return wrapped
1334
+
1335
+
1336
def custom_vjp(fn=None):
    """
    Support vjp to custom bprop for function.

    Wraps `fn` in a ``Cell`` carrying the ``custom_vjp=True`` flag. Its forward pass calls `fn`, and its ``bprop``
    delegates to a backward function registered later through ``defbwd``.

    Args:
        fn (function): The `fn` that need to define custom bprop. Default: ``None``.

    Returns:
        If `fn` is given, the wrapping Cell. Otherwise a decorator that performs the wrapping when applied.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``
    """

    def _wrap(func):
        class CustomVjp(Cell):
            """
            Cell wrapper that runs ``fwd`` forward and uses ``bwd`` as its custom bprop.
            """

            def __init__(self, fwd):
                super().__init__()
                self.fwd = fwd
                # Remains None until `defbwd` registers a backward function.
                self.bwd = None
                self.add_flags(custom_vjp=True)

            def construct(self, *args):
                return self.fwd(*args)

            def defbwd(self, bwd):
                self.bwd = bwd

            def bprop(self, *args):
                return self.bwd(*args)

        return CustomVjp(func)

    if fn is None:
        return _wrap
    return _wrap(fn)
1373
+
1374
+
1375
def stop_gradient(value):
    """
    StopGradient is used for eliminating the effect of a value on the gradient, such as truncating
    the gradient propagation from an output of a function.
    For more details, please refer to `Stop Gradient
    <https://www.mindspore.cn/tutorials/en/master/beginner/autograd.html#stop-gradient>`_.

    Args:
        value (Any): The value whose effect on the gradient to be eliminated.

    Returns:
        The same as `value`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> from mindspore import dtype as mstype
        >>> def net(x, y):
        ...     out1 = ops.MatMul()(x, y)
        ...     out2 = ops.MatMul()(x, y)
        ...     out2 = ops.stop_gradient(out2)
        ...     return out1, out2
        ...
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> grad_fn = ops.grad(net)
        >>> output = grad_fn(x, y)
        >>> print(output)
        [[1.4100001 1.6       6.5999994]
         [1.4100001 1.6       6.5999994]]
    """
    # Delegate to the module-level primitive instance.
    detached = stop_gradient_(value)
    return detached
1410
+
1411
+
1412
# Public API of this module, kept in alphabetical order.
__all__ = sorted([
    'grad',
    'value_and_grad',
    'jacfwd',
    'jacrev',
    'jet',
    'derivative',
    'jvp',
    'vjp',
    'linearize',
    'stop_gradient',
    'get_grad',
])