mindspore-2.4.0-cp311-cp311-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mindspore might be problematic.

Files changed (1387)
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-311-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-311-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-311-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1844 @@
+ # Copyright 2020-2022 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ """Operators for quantization."""
+ from __future__ import absolute_import
+ from functools import partial
+
+ import mindspore.context as context
+ from mindspore import _checkparam as validator
+ from mindspore.ops.primitive import Primitive, PrimitiveWithInfer, prim_attr_register
+ from mindspore.common import dtype as mstype
+ from mindspore.common.dtype import QuantDtype
+
+
+ def _support_te():
+     try:
+         import te  # pylint: disable=unused-import
+         return True
+     # pylint: disable=broad-except
+     except Exception:
+         return False
+
+ if context.get_context('device_target') == "Ascend" and _support_te():
+     import mindspore.ops._op_impl._custom_op
+
+ __all__ = ["MinMaxUpdatePerLayer",
+            "MinMaxUpdatePerChannel",
+            "FakeLearnedScaleQuantPerLayer",
+            "FakeLearnedScaleQuantPerLayerGrad",
+            "FakeLearnedScaleQuantPerLayerGradD",
+            "FakeLearnedScaleQuantPerLayerGradDReduce",
+            "FakeLearnedScaleQuantPerChannel",
+            "FakeLearnedScaleQuantPerChannelGrad",
+            "FakeLearnedScaleQuantPerChannelGradD",
+            "FakeLearnedScaleQuantPerChannelGradDReduce",
+            "FakeQuantWithMinMaxVars",
+            "FakeQuantWithMinMaxVarsGradient",
+            "FakeQuantWithMinMaxVarsPerChannel",
+            "FakeQuantWithMinMaxVarsPerChannelGradient",
+            "FakeQuantPerLayer",
+            "FakeQuantPerLayerGrad",
+            "FakeQuantPerChannel",
+            "FakeQuantPerChannelGrad",
+            "BatchNormFold",
+            "BatchNormFoldGrad",
+            "CorrectionMul",
+            "CorrectionMulGrad",
+            "CorrectionMulGradReduce",
+            "BatchNormFold2",
+            "BatchNormFold2Grad",
+            "BatchNormFoldD",
+            "BatchNormFoldGradD",
+            "BatchNormFold2D",
+            "BatchNormFold2GradD",
+            "BatchNormFold2GradReduce",
+            "IFMR",
+            "ActsULQ",
+            "ActsULQInputGrad",
+            "ActULQClampMinGrad",
+            "ActULQClampMaxGrad",
+            "WtsARQ",
+            "FakeQuantParam",
+            ]
+
+
+ class FakeQuantParam(Primitive):
+     r"""
+     Defines the operation for storing quantization parameters. This operation passes the input tensor through to
+     the output tensor without any calculation.
+
+     Args:
+         quant_dtype (QuantDtype): The valid data type of the input tensor.
+         quant_algo_name (str): The name of the quant algorithm. Use
+             `FakeQuantParam.attr_value_linear_quant_algo_name` for linear quantization.
+         is_per_channel (bool): Whether the quant parameter is per-channel or per-layer.
+         kwargs (dict): Other quant parameters in key-value form. Please use the classmethod `linear_quant_param` to
+             create a linear quantization operator, because the keys for scale and zero-point are pre-defined by MindSpore.
+
+     Inputs:
+         - **input_x** (Tensor) : Input tensor.
+
+     Outputs:
+         - Tensor: Output tensor with the same values, shape and type as `input_x`.
+
+     Examples:
+         >>> input_tensor = mindspore.Tensor(numpy.random.rand(1, 16, 5, 5), mindspore.dtype.float32)
+         >>> fake_quant_param_op = FakeQuantParam.linear_quant_param(mindspore.common.dtype.QuantDtype.INT8,
+         ...                                                         0.5, 1)
+         >>> output_tensor = fake_quant_param_op(input_tensor)
+     """
+
+     attr_key_linear_quant_scale = "linear_quant_scale"
+     attr_key_linear_quant_zero_point = "linear_quant_zero_point"
+
+     attr_value_linear_quant_algo_name = "linear_quant_algo"
+
+     @prim_attr_register
+     def __init__(self, quant_dtype: QuantDtype, quant_algo_name: str, is_per_channel: bool, **kwargs):
+         self.add_prim_attr("quant_algo_name", quant_algo_name)
+         self.add_prim_attr("is_per_channel", is_per_channel)
+         self.add_prim_attr("quant_dtype", quant_dtype.value())
+         for key, value in kwargs.items():
+             self.add_prim_attr(key, value)
+
+     @classmethod
+     def linear_quant_param(cls, quant_dtype, scale, zp, is_per_channel=False, **kwargs):
+         """
+         Create a linear quantization operator based on scale and zero-point parameters.
+         """
+         validator.check_value_type("scale", scale, [float, tuple, list], "FakeQuantParam")
+         if isinstance(scale, float):
+             scale_list = [scale]
+         else:
+             scale_list = scale
+         validator.check_value_type("zero_point", zp, [int, tuple, list], "FakeQuantParam")
+         if isinstance(zp, int):
+             zp_list = [zp]
+         else:
+             zp_list = zp
+         validator.check_value_type("is_per_channel", is_per_channel, [bool], "FakeQuantParam")
+         kwargs[FakeQuantParam.attr_key_linear_quant_scale] = scale_list
+         kwargs[FakeQuantParam.attr_key_linear_quant_zero_point] = zp_list
+         return cls(quant_dtype, FakeQuantParam.attr_value_linear_quant_algo_name, is_per_channel, **kwargs)
+
+
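An editorial usage sketch (not part of the package diff): it mirrors the docstring example above and assumes an INT8 quant dtype with an arbitrary scale of 0.5 and zero point of 1; the op simply records those attributes and passes data through.

    import numpy as np
    import mindspore
    from mindspore.ops.operations._quant_ops import FakeQuantParam

    # Build the pass-through op; scale/zero-point are stored as primitive attributes.
    op = FakeQuantParam.linear_quant_param(mindspore.common.dtype.QuantDtype.INT8, 0.5, 1)
    x = mindspore.Tensor(np.random.rand(1, 16, 5, 5), mindspore.float32)
    y = op(x)  # y keeps the values, shape and dtype of x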
+ class MinMaxUpdatePerLayer(PrimitiveWithInfer):
+     r"""
+     Updates min and max per layer.
+
+     Args:
+         ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: ``False``.
+         ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
+
+     Inputs:
+         - **x** (Tensor) : float32 Tensor of input data whose min and max ranges are updated.
+         - **min** (Tensor) : Value of the min range of the input data x.
+         - **max** (Tensor) : Value of the max range of the input data x.
+
+     Outputs:
+         - **min_up** (Tensor) : Updated min value.
+         - **max_up** (Tensor) : Updated max value.
+
+     Examples:
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+         >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+         >>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
+     """
+     support_quant_bit = [4, 7, 8]
+
+     @prim_attr_register
+     def __init__(self, ema=False, ema_decay=0.999):
+         """Initialize FakeQuantMinMaxPerLayerUpdate OP"""
+         if context.get_context('device_target') == "Ascend":
+             from mindspore.ops._op_impl._custom_op import minmax_update_perlayer
+         if ema and not ema_decay:
+             raise ValueError(
+                 f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+         self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
+         self.init_prim_io_names(inputs=['x', 'min', 'max'],
+                                 outputs=['min_up', 'max_up'])
+
+     def infer_shape(self, x_shape, min_shape, max_shape):
+         validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
+         validator.check("min shape", min_shape, "max shape",
+                         max_shape, validator.EQ, self.name)
+         validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
+         return min_shape, max_shape
+
+     def infer_dtype(self, x_type, min_type, max_type):
+         tuple(map(partial(validator.check_tensor_dtype_valid,
+                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                   ("x", "min", "max"),
+                   (x_type, min_type, max_type)))
+         return min_type, max_type
+
+
+ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
+     r"""
+     Updates min and max per channel.
+
+     Args:
+         ema (bool): Whether to use the EMA algorithm to update the min and max values. Default: ``False``.
+         ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
+         channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
+
+     Inputs:
+         - **x** (Tensor) : float32 Tensor of input data whose per-channel min and max ranges are updated.
+         - **min** (Tensor) : Value of the min range of the input data x.
+         - **max** (Tensor) : Value of the max range of the input data x.
+
+     Outputs:
+         - **min_up** (Tensor) : Updated per-channel min value.
+         - **max_up** (Tensor) : Updated per-channel max value.
+
+     Examples:
+         >>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> min_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+         >>> max_value = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
+         >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min_value, max_value)
+     """
+     support_quant_bit = [4, 7, 8]
+     ascend_support_x_rank = [2, 4]
+
+     @prim_attr_register
+     def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
+         """Initialize FakeQuantPerChannelUpdate OP for Ascend"""
+         self.is_ascend = context.get_context('device_target') == "Ascend"
+         if self.is_ascend:
+             from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
+         if ema and not ema_decay:
+             raise ValueError(
+                 f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
+
+         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
+         self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
+         if self.is_ascend:
+             self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
+                                                           'channel_axis', self.name)
+         else:
+             self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
+         self.init_prim_io_names(
+             inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
+
+     def infer_shape(self, x_shape, min_shape, max_shape):
+         if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
+             raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
+         if not self.is_ascend:
+             validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
+         validator.check("min shape", min_shape, "max shape",
+                         max_shape, validator.EQ, self.name)
+         validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
+         return min_shape, max_shape
+
+     def infer_dtype(self, x_type, min_type, max_type):
+         tuple(map(partial(validator.check_tensor_dtype_valid,
+                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                   ("x", "min", "max"),
+                   (x_type, min_type, max_type)))
+         return min_type, max_type
+
+
+ class FakeLearnedScaleQuantPerLayer(PrimitiveWithInfer):
+     r"""
+     Simulates the quantize and dequantize operations of the fake learned scale quant per-layer case in training time.
+
+     Args:
+         quant_delay (int): Quantization delay parameter. Simulated quantization is not applied during the first
+             `quant_delay` training steps and is enabled afterwards. Default: 0.
+         neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: ``False``.
+         training (bool): Training the network or not. Default: ``True``.
+
+     Inputs:
+         - **input_x** (Tensor) : Input tensor that needs to be quantized.
+         - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
+         - **quant_max** (Tensor) : Value of the quantization range.
+
+     Outputs:
+         - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.
+
+     Examples:
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> alpha_tensor = Tensor(np.array([6]), mstype.float32)
+         >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
+         >>> output_tensor = FakeLearnedScaleQuantPerLayer()(input_tensor, alpha_tensor, quant_max_tensor)
+     """
+     @prim_attr_register
+     def __init__(self,
+                  quant_delay=0,
+                  neg_trunc=False,
+                  training=True):
+         """init FakeLearnedScaleQuantPerLayer OP"""
+         if context.get_context('device_target') == "Ascend":
+             from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer
+
+         self.quant_delay = validator.check_non_negative_int(
+             quant_delay, 'quant_delay', self.name)
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.training = validator.check_value_type(
+             'training', training, (bool,), self.name)
+         self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
+                                 outputs=['out'])
+
+     def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
+         validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
+         validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
+         validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
+         return input_x_shape
+
+     def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
+         if context.get_context('device_target') == "GPU":
+             valid_dtypes = (mstype.float32,)
+         else:
+             valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("input_x", "alpha", "quant_max"),
+                   (input_x_type, alpha_type, quant_max_type)))
+         return input_x_type
+
+
+ class FakeLearnedScaleQuantPerLayerGrad(PrimitiveWithInfer):
+     r"""
+     Performs the gradient of the FakeLearnedScaleQuantPerLayer operation.
+
+     Examples:
+         >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerLayerGrad()
+         >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
+         >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
+         >>> _alpha = Tensor(np.array([6]), mindspore.float32)
+         >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
+         >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  quant_delay=0,
+                  neg_trunc=False):
+         self.quant_delay = validator.check_non_negative_int(
+             quant_delay, 'quant_delay', self.name)
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.init_prim_io_names(
+             inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
+
+     def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
+         validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
+         validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
+         validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
+         return dout_shape, alpha_shape
+
+     def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
+         if context.get_context('device_target') == "GPU":
+             valid_dtypes = (mstype.float32,)
+         else:
+             valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("dout", "x", "alpha", "quant_max"),
+                   (dout_type, x_type, alpha_type, quant_max_type)))
+         return dout_type, alpha_type
+
+
+ class FakeLearnedScaleQuantPerLayerGradD(PrimitiveWithInfer):
+     r"""
+     Performs input grad of FakeLearnedScaleQuantPerLayer operation.
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  neg_trunc=False):
+         from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.init_prim_io_names(
+             inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
+
+     def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
+         validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
+         validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
+         validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
+         return dout_shape, dout_shape
+
+     def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
+         valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("dout", "x", "alpha", "quant_max"),
+                   (dout_type, x_type, alpha_type, quant_max_type)))
+         return dout_type, dout_type
+
+
+ class FakeLearnedScaleQuantPerLayerGradDReduce(PrimitiveWithInfer):
+     r"""
+     Performs alpha grad reduce of FakeLearnedScaleQuantPerLayer operation.
+     """
+
+     @prim_attr_register
+     def __init__(self):
+         from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perlayer_grad_reduce
+         self.init_prim_io_names(
+             inputs=['dout_alpha'], outputs=['dalpha'])
+
+     def infer_shape(self, dout_alpha_shape):
+         return (1,)
+
+     def infer_dtype(self, dout_alpha_type):
+         valid_dtypes = (mstype.float16, mstype.float32)
+         validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
+         return dout_alpha_type
+
+
+ class FakeLearnedScaleQuantPerChannel(PrimitiveWithInfer):
+     r"""
+     Simulates the quantize and dequantize operations of the fake learned scale quant per-channel case in training time.
+
+     Args:
+         quant_delay (int): Quantization delay parameter. Simulated quantization is not applied during the first
+             `quant_delay` training steps and is enabled afterwards. Default: 0.
+         neg_trunc (bool): Whether the quantization algorithm uses negative truncation or not. Default: ``False``.
+         training (bool): Training the network or not. Default: ``True``.
+         channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
+
+     Inputs:
+         - **input_x** (Tensor) : Input tensor that needs to be quantized.
+         - **alpha** (Tensor) : Value of the max clipping range of the input data `input_x`.
+         - **quant_max** (Tensor) : Value of the quantization range.
+
+     Outputs:
+         - Tensor: Simulated quantized tensor of `input_x`, with the same type and shape as `input_x`.
+
+     Examples:
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> alpha_tensor = Tensor(np.array([6]*3), mstype.float32)
+         >>> quant_max_tensor = Tensor(np.array([127]), mstype.float32)
+         >>> output_tensor = FakeLearnedScaleQuantPerChannel()(input_tensor, alpha_tensor, quant_max_tensor)
+     """
+     ascend_support_x_rank = [2, 4]
+
+     @prim_attr_register
+     def __init__(self,
+                  quant_delay=0,
+                  neg_trunc=False,
+                  training=True,
+                  channel_axis=1):
+         """init FakeLearnedScaleQuantPerChannel OP"""
+         if context.get_context('device_target') == "Ascend":
+             from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel
+         self.is_ascend = context.get_context('device_target') == "Ascend"
+         self.quant_delay = validator.check_non_negative_int(
+             quant_delay, 'quant_delay', self.name)
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.training = validator.check_value_type(
+             'training', training, (bool,), self.name)
+         if self.is_ascend:
+             self.channel_axis = validator.check_int_range(channel_axis, 0, 1, validator.INC_BOTH,
+                                                           'channel_axis', self.name)
+         else:
+             self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
+         self.init_prim_io_names(inputs=['input_x', 'alpha', 'quant_max'],
+                                 outputs=['out'])
+
+     def infer_shape(self, input_x_shape, alpha_shape, quant_max_shape):
+         if self.is_ascend and len(input_x_shape) not in self.ascend_support_x_rank:
+             raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
+         if not self.is_ascend:
+             validator.check_int(len(input_x_shape), 1, validator.GE, "input_x rank", self.name)
+         if len(input_x_shape) == 1:
+             self.channel_axis = 0
+
+         validator.check_equal_int(alpha_shape[0], input_x_shape[self.channel_axis], "alpha rank", self.name)
+         validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
+         return input_x_shape
+
+     def infer_dtype(self, input_x_type, alpha_type, quant_max_type):
+         if context.get_context('device_target') == "GPU":
+             valid_dtypes = (mstype.float32,)
+         else:
+             valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("input_x", "alpha", "quant_max"),
+                   (input_x_type, alpha_type, quant_max_type)))
+         return input_x_type
+
+
+ class FakeLearnedScaleQuantPerChannelGrad(PrimitiveWithInfer):
+     r"""
+     Performs the gradient of the FakeLearnedScaleQuantPerChannel operation.
+
+     Examples:
+         >>> fake_learned_scale_grad = FakeLearnedScaleQuantPerChannelGrad()
+         >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
+         >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
+         >>> _alpha = Tensor(np.array([6]*2), mindspore.float32)
+         >>> _quant_max = Tensor(np.array([127]), mindspore.float32)
+         >>> result = fake_learned_scale_grad(dout, input_x, _alpha, _quant_max)
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  quant_delay=0,
+                  neg_trunc=False,
+                  channel_axis=1):
+         self.quant_delay = validator.check_non_negative_int(
+             quant_delay, 'quant_delay', self.name)
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
+         self.init_prim_io_names(
+             inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
+
+     def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
+         validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
+         return dout_shape, alpha_shape
+
+     def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
+         if context.get_context('device_target') == "GPU":
+             valid_dtypes = (mstype.float32,)
+         else:
+             valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("dout", "x", "alpha", "quant_max"),
+                   (dout_type, x_type, alpha_type, quant_max_type)))
+         return dout_type, alpha_type
+
+
+ class FakeLearnedScaleQuantPerChannelGradD(PrimitiveWithInfer):
+     r"""
+     Performs input grad of FakeLearnedScaleQuantPerChannel operation.
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  neg_trunc=False,
+                  channel_axis=1):
+         from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad
+         self.neg_trunc = validator.check_value_type(
+             'neg_trunc', neg_trunc, (bool,), self.name)
+         self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
+         self.init_prim_io_names(
+             inputs=['dout', 'x', 'alpha', 'quant_max'], outputs=['dx', 'dalpha'])
+
+     def infer_shape(self, dout_shape, x_shape, alpha_shape, quant_max_shape):
+         validator.check("dout shape", dout_shape, "x_shape", x_shape, validator.EQ, self.name)
+         validator.check_int(len(alpha_shape), 1, validator.GE, "alpha rank", self.name)
+         validator.check_int(len(quant_max_shape), 1, validator.GE, "quant max rank", self.name)
+         return dout_shape, dout_shape
+
+     def infer_dtype(self, dout_type, x_type, alpha_type, quant_max_type):
+         valid_dtypes = (mstype.float16, mstype.float32)
+         tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=self.name),
+                   ("dout", "x", "alpha", "quant_max"),
+                   (dout_type, x_type, alpha_type, quant_max_type)))
+         return dout_type, dout_type
+
+
+ class FakeLearnedScaleQuantPerChannelGradDReduce(PrimitiveWithInfer):
+     r"""
+     Performs alpha grad reduce of FakeLearnedScaleQuantPerChannel operation.
+     """
+
+     @prim_attr_register
+     def __init__(self, channel_axis=1):
+         from mindspore.ops._op_impl._custom_op import fake_learned_scale_quant_perchannel_grad_reduce
+         self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
+         self.init_prim_io_names(
+             inputs=['dout_alpha'], outputs=['dalpha'])
+
+     def infer_shape(self, dout_alpha_shape):
+         return (dout_alpha_shape[self.channel_axis],)
+
+     def infer_dtype(self, dout_alpha_type):
+         valid_dtypes = (mstype.float16, mstype.float32)
+         validator.check_tensor_dtype_valid("dout_alpha", dout_alpha_type, valid_dtypes, self.name)
+         return dout_alpha_type
+
+
+ class FakeQuantWithMinMaxVars(PrimitiveWithInfer):
+     r"""
+     Fake-quantize the input by min and max.
+
+     Args:
+         num_bits (int): Quantization bitwidth; between 2 and 16. Default: 8.
+         narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+             If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+             range is [0, 2^num_bits-1]. Default: ``False``.
+
+     Inputs:
+         - **x** (Tensor) - float32 input tensor to be fake-quantized.
+         - **min** (Tensor) - Value of the min range of the input data x.
+         - **max** (Tensor) - Value of the max range of the input data x.
+
+     Outputs:
+         - Tensor, the data type and shape of the output tensor are the same as input x.
+
+     Examples:
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+         >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+         >>> output_tensor = FakeQuantWithMinMaxVars(num_bits=8, narrow_range=False)(
+         ...     input_tensor, min_tensor, max_tensor)
+         >>> output_tensor # shape: (3, 16, 5, 5) data type: mstype.float32
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  num_bits=8,
+                  narrow_range=False):
+         self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+         self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
+         self.narrow_range = validator.check_value_type(
+             'narrow_range', narrow_range, (bool,), self.name)
+
+     def check_broadcast(self, min_shape, input_shape):
+         shape_val = 1
+         for shape in input_shape:
+             shape_val = shape_val * shape
+         if min_shape[0] > 1 and min_shape[0] != shape_val:
+             raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
+
+     def infer_shape(self, x_shape, min_shape, max_shape):
+         validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
+         validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
+         validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
+         self.check_broadcast(min_shape, x_shape)
+         return x_shape
+
+     def infer_dtype(self, x_type, min_type, max_type):
+         tuple(map(partial(validator.check_tensor_dtype_valid,
+                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                   ("x", "min", "max"),
+                   (x_type, min_type, max_type)))
+         return x_type
+
+
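For orientation only (editorial note, not part of the package diff): a NumPy sketch of the generic min/max fake-quantization technique that ops of this family simulate, i.e. map values onto an integer grid of 2^num_bits levels and immediately dequantize. The zero-point nudging below is an assumption about the common formulation, not a description of this kernel's exact behaviour.

    import numpy as np

    def fake_quant_ref(x, min_val, max_val, num_bits=8, narrow_range=False):
        qmin = 1 if narrow_range else 0
        qmax = 2 ** num_bits - 1
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = np.clip(np.round(qmin - min_val / scale), qmin, qmax)
        q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)  # quantize
        return (q - zero_point) * scale                            # dequantize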
+ class FakeQuantWithMinMaxVarsGradient(PrimitiveWithInfer):
+     r"""
+     Performs the gradient of the FakeQuantWithMinMaxVars operation.
+
+     Args:
+         num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+         narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+             If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+             range is [0, 2^num_bits-1]. Default: ``False``.
+
+     Inputs:
+         - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVars.
+         - **x** (Tensor) - float32 input tensor of the forward operator.
+         - **min** (Tensor) - Value of the min range of the input data x.
+         - **max** (Tensor) - Value of the max range of the input data x.
+
+     Outputs:
+         - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
+         - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
+         - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
+
+     Examples:
+         >>> gradients = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
+         >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
+         >>> max_tensor = Tensor(np.array([6]), mstype.float32)
+         >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsGradient(num_bits=8, narrow_range=False)(
+         ...     gradients, input_tensor, min_tensor, max_tensor)
+         >>> x_gradient   # shape: (3, 16, 5, 5) data type: mstype.float32
+         >>> min_gradient # shape: (1,) data type: mstype.float32
+         >>> max_gradient # shape: (1,) data type: mstype.float32
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  num_bits=8,
+                  narrow_range=False):
+         self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+         self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
+         self.narrow_range = validator.check_value_type(
+             'narrow_range', narrow_range, (bool,), self.name)
+
+     def check_broadcast(self, min_shape, input_shape):
+         shape_val = 1
+         for shape in input_shape:
+             shape_val = shape_val * shape
+         if min_shape[0] > 1 and min_shape[0] != shape_val:
+             raise ValueError(f"For '{self.name}', the shape of \'min\' cannot broadcast to the shape of \'x\'.")
+
+     def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
+         validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
+         validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
+         validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
+         validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
+         self.check_broadcast(min_shape, x_shape)
+         return x_shape, min_shape, max_shape
+
+     def infer_dtype(self, dout_type, x_type, min_type, max_type):
+         tuple(map(partial(validator.check_tensor_dtype_valid,
+                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                   ('dout', "x", "min", "max"),
+                   (dout_type, x_type, min_type, max_type)))
+         return x_type, min_type, max_type
+
+
+ class FakeQuantWithMinMaxVarsPerChannel(PrimitiveWithInfer):
+     r"""
+     Fake-quantize the input, which is of shape [d], [b, d] or [b, h, w, d], by per-channel min and max.
+
+     Args:
+         num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
+         narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
+             If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
+             range is [0, 2^num_bits-1]. Default: ``False``.
+
+     Inputs:
+         - **x** (Tensor) - float32 input tensor to be fake-quantized.
+         - **min** (Tensor) - Value of the min range of the input data x.
+         - **max** (Tensor) - Value of the max range of the input data x.
+
+     Outputs:
+         - Tensor, the data type and shape of the output tensor are the same as input x.
+
+     Examples:
+         >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
+         >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
+         >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
+         >>> output_tensor = FakeQuantWithMinMaxVarsPerChannel(num_bits=8, narrow_range=False)(
+         ...     input_tensor, min_tensor, max_tensor)
+         >>> output_tensor # shape: (3, 16, 3, 4) data type: mstype.float32
+     """
+
+     @prim_attr_register
+     def __init__(self,
+                  num_bits=8,
+                  narrow_range=False):
+         self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
+         self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
+         self.narrow_range = validator.check_value_type(
+             'narrow_range', narrow_range, (bool,), self.name)
+
+     def infer_shape(self, x_shape, min_shape, max_shape):
+         validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
+         validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
+         validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
+         validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
+         return x_shape
+
+     def infer_dtype(self, x_type, min_type, max_type):
+         tuple(map(partial(validator.check_tensor_dtype_valid,
+                           valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
+                   ("x", "min", "max"),
+                   (x_type, min_type, max_type)))
+         return x_type
+
+
745
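For the per-channel variant, each position along the last axis of `x` gets its own (min, max) pair, matching the shape checks in infer_shape above. The following NumPy sketch of that broadcasting is illustrative only (the helper name is an assumption), not the op's actual kernel.

import numpy as np

def fake_quant_per_channel_reference(x, mins, maxs, num_bits=8, narrow_range=False):
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (maxs - mins) / (quant_max - quant_min)      # one scale per channel, shape (d,)
    zero_point = quant_min - mins / scale                # one zero point per channel, shape (d,)
    # NumPy broadcasting applies each (scale, zero_point) pair along the last axis of x.
    q = np.clip(np.round(x / scale + zero_point), quant_min, quant_max)
    return ((q - zero_point) * scale).astype(np.float32)

x = np.random.rand(3, 16, 3, 4).astype(np.float32)
mins = np.array([-6.0, -1.0, -2.0, -3.0], dtype=np.float32)
maxs = np.array([6.0, 1.0, 2.0, 3.0], dtype=np.float32)
y = fake_quant_per_channel_reference(x, mins, maxs)      # same shape as x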
+ class FakeQuantWithMinMaxVarsPerChannelGradient(PrimitiveWithInfer):
746
+ r"""
747
+ Performs grad of the FakeQuantWithMinMaxVarsPerChannel operation.
748
+
749
+ Args:
750
+ num_bits (int): Quantization bitwidth; between 2 and 16, inclusive. Default: 8.
751
+ narrow_range (bool): Whether the quantization algorithm uses narrow range or not.
752
+ If True, the quantization range is [1, 2^num_bits-1]. Otherwise, the quantization
753
+ range is [0, 2^num_bits-1]. Default: ``False``.
754
+
755
+ Inputs:
756
+ - **gradients** (Tensor) - The gradient above the FakeQuantWithMinMaxVarsPerChannel.
757
+ - **x** (Tensor) - float32 tensor, the same input `x` as the forward operator.
758
+ - **min** (Tensor) - Value of the min range of the input data x.
759
+ - **max** (Tensor) - Value of the max range of the input data x.
760
+
761
+ Outputs:
762
+ - **backprops_wrt_x** (Tensor) - The gradient of input x, with the same shape and data type as input x.
764
+ - **backprops_wrt_min** (Tensor) - The gradient of input min, with the same shape and data type as input min.
765
+ - **backprops_wrt_max** (Tensor) - The gradient of input max, with the same shape and data type as input max.
765
+
766
+ Examples:
767
+ >>> gradients = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
768
+ >>> input_tensor = Tensor(np.random.rand(3, 16, 3, 4), mstype.float32)
769
+ >>> min_tensor = Tensor(np.array([-6, -1, -2, -3]), mstype.float32)
770
+ >>> max_tensor = Tensor(np.array([6, 1, 2, 3]), mstype.float32)
771
+ >>> x_gradient, min_gradient, max_gradient = FakeQuantWithMinMaxVarsPerChannelGradient(
772
+ ... num_bits=8, narrow_range=False)(
773
+ ... gradients, input_tensor, min_tensor, max_tensor)
774
+ >>> x_gradient # shape: (3, 16, 3, 4) data type: mstype.float32
775
+ >>> min_gradient # shape: (4,) data type: mstype.float32
776
+ >>> max_gradient # shape: (4,) data type: mstype.float32
777
+ """
778
+
779
+ @prim_attr_register
780
+ def __init__(self,
781
+ num_bits=8,
782
+ narrow_range=False):
783
+ self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
784
+ self.num_bits = validator.check_int_range(self.num_bits, 2, 16, validator.INC_BOTH, 'num_bits', self.name)
785
+ self.narrow_range = validator.check_value_type(
786
+ 'narrow_range', narrow_range, (bool,), self.name)
787
+
788
+ def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
789
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
790
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
791
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
792
+ validator.check_int(len(min_shape), 1, validator.EQ, "min shape", self.name)
793
+ validator.check("min shape", min_shape[0], "x shape", x_shape[-1], validator.EQ, self.name)
794
+ return x_shape, min_shape, max_shape
795
+
796
+ def infer_dtype(self, dout_type, x_type, min_type, max_type):
797
+ tuple(map(partial(validator.check_tensor_dtype_valid,
798
+ valid_dtypes=(mstype.float16, mstype.float32), prim_name=self.name),
799
+ ("dout", "x", "min", "max"),
800
+ (dout_type, x_type, min_type, max_type)))
801
+ return x_type, min_type, max_type
802
+
803
+
804
+ def _fake_quant_per_infer_dtype(prim_name, x_type, min_type, max_type):
805
+ if context.get_context('device_target') == "GPU":
806
+ valid_dtypes = (mstype.float32,)
807
+ else:
808
+ valid_dtypes = (mstype.float16, mstype.float32)
809
+ tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
810
+ ("x", "min", "max"),
811
+ (x_type, min_type, max_type)))
812
+ return x_type
813
+
814
+
815
+ def _fake_quant_per_grad_infer_dtype(prim_name, dout_type, x_type, min_type, max_type):
816
+ if context.get_context('device_target') == "GPU":
817
+ valid_dtypes = (mstype.float32,)
818
+ else:
819
+ valid_dtypes = (mstype.float16, mstype.float32)
820
+ tuple(map(partial(validator.check_tensor_dtype_valid, valid_dtypes=valid_dtypes, prim_name=prim_name),
821
+ ("dout", "x", "min", "max"),
822
+ (dout_type, x_type, min_type, max_type)))
823
+ return dout_type
824
+
825
+
826
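FakeQuantPerLayer's `ema` and `ema_decay` arguments refer to an exponential-moving-average update of the observed data range. Below is a minimal NumPy sketch of a common formulation of that update; the op's exact rule may differ and the helper name is an assumption of this sketch.

import numpy as np

def ema_update(running_min, running_max, x, decay=0.999):
    # Smoothly track the observed activation range across training steps.
    batch_min, batch_max = float(x.min()), float(x.max())
    new_min = decay * running_min + (1.0 - decay) * batch_min
    new_max = decay * running_max + (1.0 - decay) * batch_max
    return new_min, new_max

running_min, running_max = -6.0, 6.0
x = np.random.uniform(-5.0, 7.0, size=(3, 16, 5, 5)).astype(np.float32)
running_min, running_max = ema_update(running_min, running_max, x, decay=0.999)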
+ class FakeQuantPerLayer(PrimitiveWithInfer):
827
+ r"""
828
+ Simulates the quantize and dequantize operations at training time.
829
+
830
+ Args:
831
+ num_bits (int) : Number of bits for quantization-aware training. Default: 8.
832
+ ema (bool): Uses the EMA algorithm to update the min and max values. Default: ``False``.
833
+ ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
834
+ quant_delay (int): Quantization delay parameter. Before `quant_delay` training steps the simulated
835
+ quantization function is not applied; after that many steps the fake-quantize
836
+ function takes effect. Default: 0.
837
+ symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: ``False``.
838
+ narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: ``False``.
839
+ training (bool): Training the network or not. Default: ``True``.
840
+
841
+ Inputs:
842
+ - **x** (Tensor) : float32 Tensor to be fake-quantized.
843
+ - **min** (Tensor) : Value of the min range of the input data x.
844
+ - **max** (Tensor) : Value of the max range of the input data x.
845
+
846
+ Outputs:
847
+ - Tensor, the simulated quantized tensor of x, with the same shape and dtype as x.
848
+
849
+ Examples:
850
+ >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
851
+ >>> min_tensor = Tensor(np.array([-6]), mstype.float32)
852
+ >>> max_tensor = Tensor(np.array([6]), mstype.float32)
853
+ >>> output_tensor = FakeQuantPerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
854
+ """
855
+ support_quant_bit = [4, 7, 8]
856
+
857
+ @prim_attr_register
858
+ def __init__(self,
859
+ num_bits=8,
860
+ ema=False,
861
+ ema_decay=0.999,
862
+ quant_delay=0,
863
+ symmetric=False,
864
+ narrow_range=False,
865
+ training=True):
866
+ """Initialize FakeQuantPerLayer OP"""
867
+ if context.get_context('device_target') == "Ascend":
868
+ from mindspore.ops._op_impl._custom_op import fake_quant_perlayer
869
+ if num_bits not in self.support_quant_bit:
870
+ raise ValueError(
871
+ f"For '{self.name}' attr \'num_bits\' is not support.")
872
+ if ema and not ema_decay:
873
+ raise ValueError(
874
+ f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
875
+
876
+ self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
877
+ self.symmetric = validator.check_value_type(
878
+ 'symmetric', symmetric, (bool,), self.name)
879
+ self.narrow_range = validator.check_value_type(
880
+ 'narrow_range', narrow_range, (bool,), self.name)
881
+ self.training = validator.check_value_type('training', training, (bool,), self.name)
882
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
883
+ self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
884
+ self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
885
+ self.init_prim_io_names(inputs=['x', 'min', 'max'],
886
+ outputs=['out'])
887
+
888
+ def infer_shape(self, x_shape, min_shape, max_shape):
889
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
890
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
891
+ validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
892
+ return x_shape
893
+
894
+ def infer_dtype(self, x_type, min_type, max_type):
895
+ return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
896
+
897
+
898
+ class FakeQuantPerLayerGrad(PrimitiveWithInfer):
899
+ r"""
900
+ Performs grad of FakeQuantPerLayer operation.
901
+
902
+ Examples:
903
+ >>> fake_min_max_grad = FakeQuantPerLayerGrad()
904
+ >>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
905
+ >>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
906
+ >>> _min = Tensor(np.array([-4]), mindspore.float32)
907
+ >>> _max = Tensor(np.array([2]), mindspore.float32)
908
+ >>> result = fake_min_max_grad(dout, input_x, _min, _max)
909
+ """
910
+ support_quant_bit = [4, 7, 8]
911
+
912
+ @prim_attr_register
913
+ def __init__(self,
914
+ num_bits=8,
915
+ quant_delay=0,
916
+ symmetric=False,
917
+ narrow_range=False):
918
+ if context.get_context('device_target') == "Ascend":
919
+ from mindspore.ops._op_impl._custom_op import fake_quant_perlayer_grad
920
+ if num_bits not in self.support_quant_bit:
921
+ raise ValueError(
922
+ f"For '{self.name}' attr \'num_bits\' is not support.")
923
+
924
+ self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
925
+ self.quant_delay = validator.check_value_type(
926
+ 'quant_delay', quant_delay, (int,), self.name)
927
+ self.symmetric = validator.check_value_type(
928
+ 'symmetric', symmetric, (bool,), self.name)
929
+ self.narrow_range = validator.check_value_type(
930
+ 'narrow_range', narrow_range, (bool,), self.name)
931
+ self.init_prim_io_names(
932
+ inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
933
+
934
+ def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
935
+ validator.check("dout shape", dout_shape, "x shape",
936
+ x_shape, validator.EQ, self.name)
937
+ validator.check("min shape", min_shape, "max shape",
938
+ max_shape, validator.EQ, self.name)
939
+ validator.check_equal_int(len(min_shape), 1, "min shape", self.name)
940
+ return dout_shape
941
+
942
+ def infer_dtype(self, dout_type, x_type, min_type, max_type):
943
+ return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
944
+
945
+
946
+ class FakeQuantPerChannel(PrimitiveWithInfer):
947
+ r"""
948
+ Simulates the quantize and dequantize operations at training time on a per-channel basis.
949
+
950
+ Args:
951
+ num_bits (int) : Number of bits for quantization. Default: 8.
952
+ ema (bool): Uses the EMA algorithm to update the tensor min and tensor max. Default: ``False``.
953
+ ema_decay (float) : EMA algorithm decay parameter. Default: 0.999.
954
+ quant_delay (int): Quantization delay parameter. Before `quant_delay` training steps the weight
955
+ data is not updated to simulate the quantize operation; after that many steps the
956
+ simulated quantization begins. Default: 0.
957
+ symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: ``False``.
958
+ narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: ``False``.
959
+ training (bool): Training the network or not. Default: ``True``.
960
+ channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
961
+
962
+ Inputs:
963
+ - **x** (Tensor) : 4-D float32 Tensor to be fake-quantized.
964
+ - **min** (int, float) : Value of the min range of the input data.
965
+ - **max** (int, float) : Value of the max range of the input data.
966
+
967
+ Outputs:
968
+ - Tensor, has the same type as input.
969
+
970
+ Examples:
971
+ >>> fake_quant = FakeQuantPerChannel()
972
+ >>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
973
+ >>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
974
+ >>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
975
+ >>> result = fake_quant(input_x, _min, _max)
976
+ """
977
+ support_quant_bit = [4, 7, 8]
978
+ ascend_support_x_rank = [2, 3, 4]
979
+
980
+ @prim_attr_register
981
+ def __init__(self,
982
+ num_bits=8,
983
+ ema=False,
984
+ ema_decay=0.999,
985
+ quant_delay=0,
986
+ symmetric=False,
987
+ narrow_range=False,
988
+ training=True,
989
+ channel_axis=1):
990
+ """Initialize FakeQuantPerChannel OP"""
991
+ self.is_ascend = context.get_context('device_target') == "Ascend"
992
+ if self.is_ascend:
993
+ from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
994
+ if num_bits not in self.support_quant_bit:
995
+ raise ValueError(
996
+ f"For '{self.name}' Attr \'num_bits\' is not support.")
997
+ if ema and not ema_decay:
998
+ raise ValueError(
999
+ f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.")
1000
+
1001
+ self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
1002
+ self.symmetric = validator.check_value_type(
1003
+ 'symmetric', symmetric, (bool,), self.name)
1004
+ self.narrow_range = validator.check_value_type(
1005
+ 'narrow_range', narrow_range, (bool,), self.name)
1006
+ self.training = validator.check_value_type(
1007
+ 'training', training, (bool,), self.name)
1008
+ self.ema_decay = validator.check_float_range(ema_decay, 0, 1, validator.INC_BOTH, 'ema_decay', self.name)
1009
+ self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
1010
+ self.quant_delay = validator.check_non_negative_int(quant_delay, 'quant_delay', self.name)
1011
+ self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel_axis', self.name)
1012
+ self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
1013
+
1014
+ def infer_shape(self, x_shape, min_shape, max_shape):
1015
+ if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
1016
+ raise ValueError(f"For '{self.name}' x rank must be in '{self.ascend_support_x_rank}'")
1017
+ if not self.is_ascend:
1018
+ validator.check_int(len(x_shape), 1, validator.GE, "x rank", self.name)
1019
+ if len(x_shape) == 1:
1020
+ self.channel_axis = 0
1021
+ validator.check("min shape", min_shape, "max shape", max_shape, validator.EQ, self.name)
1022
+ validator.check_equal_int(min_shape[0], x_shape[self.channel_axis], "min shape", self.name)
1023
+ validator.check_equal_int(max_shape[0], x_shape[self.channel_axis], "max shape", self.name)
1024
+ return x_shape
1025
+
1026
+ def infer_dtype(self, x_type, min_type, max_type):
1027
+ return _fake_quant_per_infer_dtype(self.name, x_type, min_type, max_type)
1028
+
1029
+
1030
+ class FakeQuantPerChannelGrad(PrimitiveWithInfer):
1031
+ r"""
1032
+ Performs grad of FakeQuantPerChannel operation.
1033
+
1034
+ Examples:
1035
+ >>> fqmmpc_grad = FakeQuantPerChannelGrad()
1036
+ >>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
1037
+ >>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
1038
+ >>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
1039
+ >>> _max = Tensor(np.random.randint(-2, 8, (2, 3, 4)), mindspore.float32)
1040
+ >>> result = fqmmpc_grad(dout, input_x, _min, _max)
1041
+ """
1042
+ support_quant_bit = [4, 7, 8]
1043
+
1044
+ @prim_attr_register
1045
+ def __init__(self,
1046
+ num_bits=8,
1047
+ quant_delay=0,
1048
+ symmetric=False,
1049
+ narrow_range=False,
1050
+ channel_axis=1):
1051
+ """Initialize FakeQuantPerChannelGrad Fill"""
1052
+ if context.get_context('device_target') == "Ascend":
1053
+ from mindspore.ops._op_impl._custom_op import fake_quant_perchannel_grad
1054
+ if num_bits not in self.support_quant_bit:
1055
+ raise ValueError(
1056
+ f"For '{self.name}' attr \'num_bits\' is not support.")
1057
+
1058
+ self.num_bits = validator.check_positive_int(num_bits, 'num_bits', self.name)
1059
+ self.quant_delay = validator.check_value_type(
1060
+ 'quant_delay', quant_delay, (int,), self.name)
1061
+ self.symmetric = validator.check_value_type(
1062
+ 'symmetric', symmetric, (bool,), self.name)
1063
+ self.narrow_range = validator.check_value_type(
1064
+ 'narrow_range', narrow_range, (bool,), self.name)
1065
+ self.channel_axis = validator.check_non_negative_int(channel_axis, 'channel axis', self.name)
1066
+ self.init_prim_io_names(
1067
+ inputs=['dout', 'x', 'min', 'max'], outputs=['dx'])
1068
+
1069
+ def infer_shape(self, dout_shape, x_shape, min_shape, max_shape):
1070
+ validator.check("dout shape", dout_shape, "x shape", x_shape)
1071
+ validator.check("min shape", min_shape, "max shape", max_shape)
1072
+ return dout_shape
1073
+
1074
+ def infer_dtype(self, dout_type, x_type, min_type, max_type):
1075
+ return _fake_quant_per_grad_infer_dtype(self.name, dout_type, x_type, min_type, max_type)
1076
+
1077
+
1078
+ class BatchNormFold(PrimitiveWithInfer):
1079
+ """
1080
+ Batch Normalization folded.
1081
+
1082
+ Args:
1083
+ momentum (float): Momentum value; must be in the range [0, 1]. Default: 0.9.
1084
+ epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
1085
+ float32 else 1e-3. Default: 1e-5.
1086
+ is_training (bool): In training mode set True, else set False. Default: ``True``.
1087
+ freeze_bn (int): Delay in steps at which computation switches from regular batch
1088
+ norm to frozen mean and std. Default: 0.
1089
+
1090
+ Inputs:
1091
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
1092
+ - **mean** (Tensor) - Tensor of shape :math:`(C,)`.
1093
+ - **variance** (Tensor) - Tensor of shape :math:`(C,)`.
1094
+ - **global_step** (Tensor) - Tensor to record current global step.
1095
+
1096
+ Outputs:
1097
+ Tuple of 4 Tensors, the folded batch statistics and the updated running statistics.
1098
+
1099
+ - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1100
+ - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1101
+ - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1102
+ - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1103
+
1104
+ Examples:
1105
+ >>> batch_norm_fold = P.BatchNormFold()
1106
+ >>> input_x = Tensor(np.array([1, 2, -1, -2, -2, 1]).reshape(2, 3), mindspore.float32)
1107
+ >>> mean = Tensor(np.array([0.5, -1, 1,]), mindspore.float32)
1108
+ >>> variance = Tensor(np.array([0.36, 0.4, 0.49]), mindspore.float32)
1109
+ >>> global_step = Tensor(np.arange(6), mindspore.int32)
1110
+ >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(input_x, mean, variance, global_step)
1111
+ """
1112
+ channel_axis = 1
1113
+
1114
+ @prim_attr_register
1115
+ def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1116
+ """Initialize batch norm fold layer"""
1117
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1118
+ self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1119
+ self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1120
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1121
+
1122
+ self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'],
1123
+ outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std'])
1124
+
1125
+ def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape):
1126
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1127
+ validator.check("mean_shape[0]", mean_shape[0], "input channel",
1128
+ x_shape[self.channel_axis], validator.EQ, self.name)
1129
+ validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1130
+ return mean_shape, mean_shape, mean_shape, mean_shape
1131
+
1132
+ def infer_dtype(self, x_type, mean_type, variance_type, global_step_type):
1133
+ validator.check("input type", x_type, "mean type", mean_type)
1134
+ validator.check("input type", x_type, "variance type", variance_type)
1135
+ args = {"x": x_type, "mean": mean_type, "variance": variance_type}
1136
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1137
+ validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1138
+ return x_type, x_type, x_type, x_type
1139
+
1140
+
1141
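BatchNormFold exposes the per-channel batch statistics and the momentum-updated running statistics that the folded convolution consumes. The NumPy sketch below shows a common formulation; it is illustrative only, and the reduction axes and exact running-update rule are assumptions of the sketch rather than the op's kernel.

import numpy as np

def batch_norm_fold_reference(x, running_mean, running_var, momentum=0.9, epsilon=1e-5):
    # Reduce over every axis except the channel axis (1 for NCHW inputs).
    axes = tuple(i for i in range(x.ndim) if i != 1)
    batch_mean = x.mean(axis=axes)
    batch_var = x.var(axis=axes)
    batch_std = np.sqrt(batch_var + epsilon)
    new_running_mean = momentum * running_mean + (1.0 - momentum) * batch_mean
    new_running_var = momentum * running_var + (1.0 - momentum) * batch_var
    running_std = np.sqrt(new_running_var + epsilon)
    return batch_mean, batch_std, new_running_mean, running_std

x = np.random.randn(4, 3, 8, 8).astype(np.float32)
stats = batch_norm_fold_reference(x, np.zeros(3, np.float32), np.ones(3, np.float32))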
+ class BatchNormFoldGrad(PrimitiveWithInfer):
1142
+ r"""
1143
+ Performs grad of BatchNormFold operation.
1144
+
1145
+ Examples:
1146
+ >>> batch_norm_fold_grad = ops.BatchNormFoldGrad()
1147
+ >>> d_batch_mean = Tensor(np.random.randint(-2., 2., (1, 2, 2, 3)), mindspore.float32)
1148
+ >>> d_batch_std = Tensor(np.random.randn(1, 2, 2, 3), mindspore.float32)
1149
+ >>> input_x = Tensor(np.random.randint(0, 256, (4, 1, 4, 6)), mindspore.float32)
1150
+ >>> batch_mean = Tensor(np.random.randint(-8., 8., (1, 2, 2, 3)), mindspore.float32)
1151
+ >>> batch_std = Tensor(np.random.randint(0, 12, (1, 2, 2, 3)), mindspore.float32)
1152
+ >>> global_step = Tensor([2], mindspore.int32)
1153
+ >>> result = batch_norm_fold_grad(d_batch_mean, d_batch_std, input_x, batch_mean, batch_std, global_step)
1154
+ """
1155
+ channel_axis = 1
1156
+
1157
+ @prim_attr_register
1158
+ def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
1159
+ """Initialize BatchNormGrad layer"""
1160
+ self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1161
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1162
+ self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1163
+ self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'],
1164
+ outputs=['dx'])
1165
+
1166
+ def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape,
1167
+ global_step_shape):
1168
+ validator.check("d_batch_mean shape", d_batch_mean_shape,
1169
+ "d_batch_std shape", d_batch_std_shape, validator.EQ, self.name)
1170
+ validator.check("d_batch_mean shape", d_batch_mean_shape,
1171
+ "batch_mean shape", batch_mean_shape, validator.EQ, self.name)
1172
+ validator.check("d_batch_mean shape", d_batch_mean_shape,
1173
+ "batch_std shape", batch_std_shape, validator.EQ, self.name)
1174
+ validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0],
1175
+ "input channel", x_shape[self.channel_axis], validator.EQ, self.name)
1176
+ validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1177
+ return x_shape
1178
+
1179
+ def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type,
1180
+ global_step_type):
1181
+ args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type,
1182
+ "batch_mean": batch_mean_type, "batch_std": batch_std_type}
1183
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1184
+ validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1185
+ return x_type
1186
+
1187
+
1188
+ class CorrectionMul(PrimitiveWithInfer):
1189
+ """
1190
+ Scales the weights with a correction factor to the long term statistics
1191
+ prior to quantization. This ensures that there is no jitter in the quantized weights
1192
+ due to batch to batch variation.
1193
+
1194
+ Inputs:
1195
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
1196
+ - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1197
+ - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1198
+
1199
+ Outputs:
1200
+ - **out** (Tensor) - Tensor has the same shape as x.
1201
+
1202
+ Examples:
1203
+ >>> correction_mul = ops.CorrectionMul()
1204
+ >>> input_x = Tensor(np.random.randint(-8, 12, (3, 4)), mindspore.float32)
1205
+ >>> batch_std = Tensor(np.array([1.5, 3, 2]), mindspore.float32)
1206
+ >>> running_std = Tensor(np.array([2, 1.2, 0.5]), mindspore.float32)
1207
+ >>> out = correction_mul(input_x, batch_std, running_std)
1208
+ """
1209
+
1210
+ @prim_attr_register
1211
+ def __init__(self, channel_axis=0):
1212
+ """Initialize correction mul layer"""
1213
+ if context.get_context('device_target') == "Ascend":
1214
+ from mindspore.ops._op_impl._custom_op import correction_mul
1215
+ self.channel_axis = channel_axis
1216
+ self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'],
1217
+ outputs=['out'])
1218
+
1219
+ def infer_shape(self, x_shape, batch_std_shape, running_std_shape):
1220
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1221
+ running_std_shape, validator.EQ, self.name)
1222
+ validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1223
+ validator.EQ, self.name)
1224
+ return x_shape
1225
+
1226
+ def infer_dtype(self, x_type, batch_std_type, running_std_type):
1227
+ args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type}
1228
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1229
+ return x_type
1230
+
1231
+
1232
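CorrectionMul multiplies each channel of its input by the ratio batch_std / running_std so that the folded weights reflect the long-term statistics. A NumPy sketch of that per-channel scaling follows; it is illustrative only and the helper name is an assumption.

import numpy as np

def correction_mul_reference(x, batch_std, running_std, channel_axis=0):
    factor = batch_std / running_std
    # Reshape the factor so it broadcasts along the chosen channel axis.
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    return x * factor.reshape(shape)

x = np.random.randn(3, 4).astype(np.float32)
batch_std = np.array([1.5, 3.0, 2.0], dtype=np.float32)
running_std = np.array([2.0, 1.2, 0.5], dtype=np.float32)
out = correction_mul_reference(x, batch_std, running_std, channel_axis=0)   # same shape as x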
+ class CorrectionMulGrad(PrimitiveWithInfer):
1233
+ r"""
1234
+ Performs grad of CorrectionMul operation.
1235
+
1236
+ Examples:
1237
+ >>> correction_mul_grad = ops.CorrectionMulGrad()
1238
+ >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
1239
+ >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
1240
+ >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
1241
+ >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
1242
+ >>> result = correction_mul_grad(dout, input_x, gamma, running_std)
1243
+ """
1244
+
1245
+ @prim_attr_register
1246
+ def __init__(self, channel_axis=0):
1247
+ """Initialize correction mul layer"""
1248
+ if context.get_context('device_target') == "Ascend":
1249
+ from mindspore.ops._op_impl._custom_op import correction_mul_grad
1250
+ self.channel_axis = channel_axis
1251
+ self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'],
1252
+ outputs=['dx', 'mul_dx'])
1253
+
1254
+ def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape):
1255
+ validator.check("dout shape", dout_shape, "x_shape x", x_shape, validator.EQ, self.name)
1256
+ validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel_axis],
1257
+ validator.EQ, self.name)
1258
+ validator.check("running_std_shape[0]", running_std_shape[0],
1259
+ "dout channel size", dout_shape[self.channel_axis], validator.EQ, self.name)
1260
+ if context.get_context('device_target') == "Ascend":
1261
+ return x_shape, x_shape
1262
+ return x_shape, gamma_shape
1263
+
1264
+ def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type):
1265
+ args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type}
1266
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1267
+ if context.get_context('device_target') == "Ascend":
1268
+ return x_type, x_type
1269
+ return x_type, gamma_type
1270
+
1271
+
1272
+ class CorrectionMulGradReduce(PrimitiveWithInfer):
1273
+ r"""
1274
+ Performs grad reduce of CorrectionMul operation.
1275
+
1276
+ Examples:
1277
+ >>> correction_mul_grad_rd = ops.CorrectionMulGradReduce()
1278
+ >>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
1279
+ >>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
1280
+ >>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
1281
+ >>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
1282
+ >>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
1283
+ """
1284
+
1285
+ @prim_attr_register
1286
+ def __init__(self, channel_axis=0):
1287
+ """Initialize correction mul reduce layer"""
1288
+ if context.get_context('device_target') == "Ascend":
1289
+ from mindspore.ops._op_impl._custom_op import correction_mul_grad
1290
+ self.channel_axis = channel_axis
1291
+ self.init_prim_io_names(inputs=['mul_dx'],
1292
+ outputs=['d_gamma'])
1293
+
1294
+ def infer_shape(self, mul_dx_shape):
1295
+ return [mul_dx_shape[self.channel_axis]]
1296
+
1297
+ def infer_dtype(self, mul_dx_type):
1298
+ return mul_dx_type
1299
+
1300
+
1301
+ class BatchNormFold2(PrimitiveWithInfer):
1302
+ """
1303
+ Scales the bias with a correction factor to the long term statistics
1304
+ prior to quantization. This ensures that there is no jitter in the quantized bias
1305
+ due to batch to batch variation.
1306
+
1307
+ Inputs:
1308
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
1309
+ - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
1310
+ - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
1311
+ - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1312
+ - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1313
+ - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1314
+ - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1315
+ - **global_step** (Tensor) - Tensor to record current global step.
1316
+
1317
+ Outputs:
1318
+ - **y** (Tensor) - Tensor has the same shape as x.
1319
+
1320
+ Examples:
1321
+ >>> batch_norm_fold2 = ops.BatchNormFold2()
1322
+ >>> input_x = Tensor(np.random.randint(-6, 6, (4, 3)), mindspore.float32)
1323
+ >>> beta = Tensor(np.array([0.2, -0.1, 0.25]), mindspore.float32)
1324
+ >>> gamma = Tensor(np.array([-0.1, -0.25, 0.1]), mindspore.float32)
1325
+ >>> batch_std = Tensor(np.array([0.1, 0.2, 0.1]), mindspore.float32)
1326
+ >>> batch_mean = Tensor(np.array([0, 0.05, 0.2]), mindspore.float32)
1327
+ >>> running_std = Tensor(np.array([0.1, 0.1, 0.3]), mindspore.float32)
1328
+ >>> running_mean = Tensor(np.array([-0.1, 0, -0.1]), mindspore.float32)
1329
+ >>> global_step = Tensor(np.random.randint(1, 8, (8, )), mindspore.int32)
1330
+ >>> result = batch_norm_fold2(input_x, beta, gamma, batch_std, batch_mean,
1331
+ >>> running_std, running_mean, global_step)
1332
+ """
1333
+ channel_axis = 1
1334
+
1335
+ @prim_attr_register
1336
+ def __init__(self, freeze_bn=0):
1337
+ """Initialize conv2d fold layer"""
1338
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1339
+ self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean',
1340
+ 'running_std', 'running_mean', 'global_step'],
1341
+ outputs=['y'])
1342
+
1343
+ def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape,
1344
+ running_mean_shape, global_step_shape):
1345
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1346
+ running_std_shape, validator.EQ, self.name)
1347
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1348
+ batch_mean_shape, validator.EQ, self.name)
1349
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1350
+ validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1351
+ validator.EQ, self.name)
1352
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1353
+ validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1354
+ validator.EQ, self.name)
1355
+ validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1356
+ return x_shape
1357
+
1358
+ def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type,
1359
+ running_mean_type, global_step_type):
1360
+ args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
1361
+ "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type}
1362
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1363
+ validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1364
+ return x_type
1365
+
1366
+
1367
+ class BatchNormFold2Grad(PrimitiveWithInfer):
1368
+ r"""
1369
+ Performs grad of BatchNormFold2 operation.
1370
+
1371
+ Examples:
1372
+ >>> bnf2_grad = ops.BatchNormFold2Grad()
1373
+ >>> input_x = Tensor(np.arange(3*3*12*12).reshape(6, 3, 6, 12), mindspore.float32)
1374
+ >>> dout = Tensor(np.random.randint(-32, 32, (6, 3, 6, 12)), mindspore.float32)
1375
+ >>> gamma = Tensor(np.random.randint(-4, 4, (3, 1, 1, 2)), mindspore.float32)
1376
+ >>> batch_std = Tensor(np.random.randint(0, 8, (3, 1, 1, 2)), mindspore.float32)
1377
+ >>> batch_mean = Tensor(np.random.randint(-6, 6, (3, 1, 1, 2)), mindspore.float32)
1378
+ >>> running_std = Tensor(np.linspace(0, 2, 6).reshape(3, 1, 1, 2), mindspore.float32)
1379
+ >>> running_mean = Tensor(np.random.randint(-3, 3, (3, 1, 1, 2)), mindspore.float32)
1380
+ >>> global_step = Tensor(np.array([-2]), mindspore.int32)
1381
+ >>> result = bnf2_grad(dout, input_x, gamma, batch_std, batch_mean, running_std, running_mean, global_step)
1382
+ """
1383
+ channel_axis = 1
1384
+
1385
+ @prim_attr_register
1386
+ def __init__(self, freeze_bn=0):
1387
+ """Initialize MulFold layer"""
1388
+ self.freeze_bn = freeze_bn
1389
+ self.init_prim_io_names(inputs=['dout', 'x', 'gamma',
1390
+ 'batch_std', 'batch_mean',
1391
+ 'running_std', 'running_mean', 'global_step'],
1392
+ outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx'])
1393
+
1394
+ def infer_shape(self, dout_shape, x_shape, gamma_shape,
1395
+ batch_std_shape, batch_mean_shape,
1396
+ running_std_shape, running_mean_shape, global_step_shape):
1397
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1398
+ batch_mean_shape, validator.EQ, self.name)
1399
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1400
+ running_std_shape, validator.EQ, self.name)
1401
+ validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape,
1402
+ validator.EQ, self.name)
1403
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1404
+ validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1405
+ validator.EQ, self.name)
1406
+ validator.check_equal_int(len(global_step_shape), 1, "global step shape len", self.name)
1407
+ return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape
1408
+
1409
+ def infer_dtype(self, dout_type, x_type, gamma_type,
1410
+ batch_std_type, batch_mean_type,
1411
+ running_std_type, running_mean_type, global_step_type):
1412
+ validator.check("batch_std type", batch_std_type,
1413
+ "batch_mean type", batch_mean_type)
1414
+ validator.check("batch_std type", batch_std_type,
1415
+ "gamma type", gamma_type)
1416
+ validator.check("batch_std type", batch_std_type,
1417
+ "running_std type", running_std_type)
1418
+ validator.check("batch_std type", batch_std_type,
1419
+ "running_mean type", running_mean_type)
1420
+ validator.check("batch_std_type", batch_std_type,
1421
+ "dout type", dout_type)
1422
+ args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
1423
+ "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type}
1424
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1425
+ validator.check_tensor_dtype_valid("global_step", global_step_type, (mstype.int32,), self.name)
1426
+ return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type
1427
+
1428
+
1429
+ class BatchNormFoldD(PrimitiveWithInfer):
1430
+ """Performs grad of _BatchNormFold operation."""
1431
+
1432
+ @prim_attr_register
1433
+ def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
1434
+ """Initialize _BatchNormFold layer"""
1435
+ from mindspore.ops._op_impl._custom_op import batchnorm_fold
1436
+ self.momentum = validator.check_float_range(momentum, 0, 1, validator.INC_BOTH, 'momentum', self.name)
1437
+ self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1438
+ self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1439
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1440
+ self.data_format = "NCHW"
1441
+ self.init_prim_io_names(inputs=['x', 'x_sum', 'x_square_sum', 'mean', 'variance'],
1442
+ outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std',
1443
+ 'mean_updated', 'variance_updated'])
1444
+
1445
+ def infer_shape(self, x_shape, x_sum_shape, x_square_sum_shape, mean_shape, variance_shape):
1446
+ validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, validator.EQ, self.name)
1447
+ validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[1], validator.EQ, self.name)
1448
+ return x_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape, mean_shape
1449
+
1450
+ def infer_dtype(self, x_type, x_sum_type, x_square_sum_type, mean_type, variance_type):
1451
+ validator.check("input type", x_type, "mean type", mean_type)
1452
+ validator.check("input type", x_type, "variance type", variance_type)
1453
+ args = {"x": x_type, "mean": mean_type, "variance": variance_type}
1454
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1455
+ return x_type, x_type, x_type, x_type, x_type, x_type, x_type
1456
+
1457
+
1458
+ class BatchNormFoldGradD(PrimitiveWithInfer):
1459
+ """Performs grad of BatchNormFold operation."""
1460
+
1461
+ @prim_attr_register
1462
+ def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
1463
+ """Initialize _BatchNormFoldGrad layer"""
1464
+ from mindspore.ops._op_impl._custom_op import batchnorm_fold_grad
1465
+ self.epsilon = validator.check_positive_float(epsilon, 'epsilon', self.name)
1466
+ self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
1467
+ self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
1468
+ self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std'],
1469
+ outputs=['dx'])
1470
+
1471
+ def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape):
1472
+ validator.check("d_batch_mean shape", d_batch_mean_shape, "d_batch_std shape", d_batch_std_shape)
1473
+ validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_mean shape", batch_mean_shape)
1474
+ validator.check("d_batch_mean shape", d_batch_mean_shape, "batch_std shape", batch_std_shape)
1475
+ validator.check("x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[1])
1476
+ return x_shape
1477
+
1478
+ def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type):
1479
+ validator.check("input type", x_type, "d_batch_mean type", d_batch_mean_type)
1480
+ validator.check("input type", x_type, "d_batch_std type", d_batch_std_type)
1481
+ validator.check("input type", x_type, "batch_mean type", batch_mean_type)
1482
+ validator.check("input type", x_type, "batch_std type", batch_std_type)
1483
+ validator.check_tensor_dtype_valid("input type", x_type, (mstype.float16, mstype.float32), self.name)
1484
+ return x_type
1485
+
1486
+
1487
+ class BatchNormFold2D(PrimitiveWithInfer):
1488
+ """
1489
+ Scales the bias with a correction factor to the long term statistics
1490
+ prior to quantization. This ensures that there is no jitter in the quantized bias
1491
+ due to batch to batch variation.
1492
+
1493
+ Inputs:
1494
+ - **x** (Tensor) - Tensor of shape :math:`(N, C)`.
1495
+ - **beta** (Tensor) - Tensor of shape :math:`(C,)`.
1496
+ - **gamma** (Tensor) - Tensor of shape :math:`(C,)`.
1497
+ - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`.
1498
+ - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1499
+ - **running_std** (Tensor) - Tensor of shape :math:`(C,)`.
1500
+ - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`.
1501
+ - **global_step** (Tensor) - Tensor to record current global step.
1502
+
1503
+ Outputs:
1504
+ - **y** (Tensor) - Tensor has the same shape as x.
1505
+
1506
+ """
1507
+ channel_axis = 1
1508
+
1509
+ @prim_attr_register
1510
+ def __init__(self, freeze_bn=0):
1511
+ """Initialize conv2d fold layer"""
1512
+ from mindspore.ops._op_impl._custom_op import batchnorm_fold2
1513
+ self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
1514
+ outputs=['y'])
1515
+
1516
+ def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape):
1517
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1518
+ running_std_shape, validator.EQ, self.name)
1519
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1520
+ batch_mean_shape, validator.EQ, self.name)
1521
+ validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, validator.EQ, self.name)
1522
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, validator.EQ, self.name)
1523
+ validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel_axis],
1524
+ validator.EQ, self.name)
1525
+ return x_shape
1526
+
1527
+ def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type):
1528
+ args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type,
1529
+ "beta": beta_type, "gamma": gamma_type, "x": x_type}
1530
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1531
+ return x_type
1532
+
1533
+
1534
+ class BatchNormFold2GradD(PrimitiveWithInfer):
1535
+ """Performs grad of BatchNormFold2 operation."""
1536
+ channel_axis = 1
1537
+
1538
+ @prim_attr_register
1539
+ def __init__(self, freeze_bn=False):
1540
+ """Initialize MulFold layer"""
1541
+ from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad
1542
+ self.freeze_bn = freeze_bn
1543
+ self.init_prim_io_names(
1544
+ inputs=['dout', 'dout_reduce', 'dout_x_reduce', 'gamma', 'batch_std', 'batch_mean', 'running_std'],
1545
+ outputs=['d_batch_std', 'd_batch_mean', 'd_gamma', 'dx'])
1546
+
1547
+ def infer_shape(self, dout_shape, dout_reduce_shape, dout_x_reduce_shape, gamma_shape, batch_std_shape,
1548
+ batch_mean_shape, running_std_shape):
1549
+ validator.check("batch_std shape", batch_std_shape, "batch_mean shape",
1550
+ batch_mean_shape, validator.EQ, self.name)
1551
+ validator.check("batch_std shape", batch_std_shape, "running_std shape",
1552
+ running_std_shape, validator.EQ, self.name)
1553
+ validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, validator.EQ, self.name)
1554
+ validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel_axis],
1555
+ validator.EQ, self.name)
1556
+ return gamma_shape, gamma_shape, gamma_shape, dout_shape
1557
+
1558
+ def infer_dtype(self, dout_type, dout_reduce_type, dout_x_reduce_type, gamma_type, batch_std_type,
1559
+ batch_mean_type, running_std_type):
1560
+ validator.check("batch_std type", batch_std_type,
1561
+ "batch_mean type", batch_mean_type)
1562
+ validator.check("batch_std type", batch_std_type,
1563
+ "gamma type", gamma_type)
1564
+ validator.check("batch_std type", batch_std_type,
1565
+ "running_std type", running_std_type)
1566
+ validator.check("batch_std_type", batch_std_type,
1567
+ "dout type", dout_type)
1568
+ args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type,
1569
+ "running_std": running_std_type, "dout": dout_type}
1570
+ validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
1571
+ return gamma_type, gamma_type, gamma_type, gamma_type
1572
+
1573
+
1574
+ class BatchNormFold2GradReduce(PrimitiveWithInfer):
1575
+ """Performs grad of CorrectionAddGrad operation."""
1576
+ channel_axis = 1
1577
+
1578
+ @prim_attr_register
1579
+ def __init__(self, freeze_bn=False):
1580
+ """Initialize MulFold layer"""
1581
+ from mindspore.ops._op_impl._custom_op import batchnorm_fold2_grad_reduce
1582
+ self.freeze_bn = freeze_bn
1583
+ self.init_prim_io_names(inputs=['dout', 'x'],
1584
+ outputs=['dout_reduce', 'dout_x_reduce'])
1585
+
1586
+ def infer_shape(self, dout_shape, x_shape):
1587
+ validator.check("dout shape", dout_shape, "x shape", x_shape, validator.EQ, self.name)
1588
+ return (dout_shape[self.channel_axis],), (dout_shape[self.channel_axis],)
1589
+
1590
+ def infer_dtype(self, dout_type, x_type):
1591
+ validator.check("dout type", dout_type, "x type", x_type)
1592
+ return dout_type, dout_type
1593
+
1594
+
1595
+ class ActsULQ(PrimitiveWithInfer):
1596
+ """
1597
+ The ActsULQ(Activation universal learnable quantization).
1598
+
1599
+ Args:
1600
+ fixed_min (bool): Whether to fix the clamp minimum to zero.
1602
+ num_bits (int): The number of bits used for quantization.
1602
+
1603
+ Inputs:
1604
+ - **x** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
1605
+ - **clamp_min** (Tensor) - A Tensor of clamp min with the same type as x.
1606
+ - **clamp_max** (Tensor) - A Tensor of clamp max with the same type as x.
1607
+
1608
+ Outputs:
1609
+ - **y** (Tensor) - A tensor of the fake-quantized feature map, with the same type as `x`.
1610
+ - **clamp_min** (Tensor) - A tensor of boolean masks indicating whether the data in the feature map is >= clamp_min.
1611
+ - **clamp_max** (Tensor) - A tensor of boolean masks indicating whether the data in the feature map is <= clamp_max.
1612
+ - **x_clamped_loss** (Tensor) - A tensor of clamped loss.
1613
+
1614
+ Examples:
1615
+ >>> data_type = np.float32
1616
+ >>> x= np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1617
+ >>> clamp_max = 0.7 * np.max(x)
1618
+ >>> clamp_min = 0.7 * np.min(x)
1619
+ >>> clamp_max = np.array([clamp_max], dtype=data_type)
1620
+ >>> clamp_min = np.array([clamp_min], dtype=data_type)
1621
+ >>> acts_ulq = Q.ActsULQ(fixed_min=True, num_bits=8)
1622
+ >>> quant_x, clamp_min_mask, clamp_max_mask, x_clamped_loss = acts_ulq(Tensor(x), Tensor(clamp_min),
1623
+ ... Tensor(clamp_max))
1624
+ """
1625
+ @prim_attr_register
1626
+ def __init__(self, fixed_min=False, num_bits=8):
1627
+ validator.check_value_type("fixed_min", fixed_min, [bool], self.name)
1628
+ validator.check_value_type("num_bits", num_bits, [int], self.name)
1629
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1630
+
1631
+ def infer_shape(self, x_shape, clamp_min_shape, clamp_max_shape):
1632
+ """infer shape of primitive"""
1633
+ validator.check_int(len(clamp_min_shape), len(x_shape), validator.EQ, "dims of clamp_min", self.name)
1634
+ validator.check_int(len(clamp_max_shape), len(x_shape), validator.EQ, "dims of clamp_max", self.name)
1635
+
1636
+ x_shape_len = len(x_shape)
1637
+ for i in range(x_shape_len):
1638
+ validator.check_int(clamp_min_shape[i], 1, validator.EQ, "dims of clamp_min", self.name)
1639
+ validator.check_int(clamp_max_shape[i], 1, validator.EQ, "dims of clamp_max", self.name)
1640
+
1641
+ return x_shape, x_shape, x_shape, x_shape
1642
+
1643
+ def infer_dtype(self, x_dtype, clamp_min_dtype, clamp_max_dtype):
1644
+ """infer dtype of primitive"""
1645
+ valid_types = [mstype.float32, mstype.float16]
1646
+ validator.check_tensor_dtype_valid("x", x_dtype, valid_types, self.name)
1647
+ validator.check_tensor_dtype_valid("clamp_min", clamp_min_dtype, valid_types, self.name)
1648
+ validator.check_tensor_dtype_valid("clamp_max", clamp_max_dtype, valid_types, self.name)
1649
+
1650
+ return x_dtype, mstype.bool_, mstype.bool_, x_dtype
1651
+
1652
+
1653
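ActsULQ clamps activations into a learnable [clamp_min, clamp_max] window before fake quantization; the boolean masks it returns record which elements fell inside that window, which is what the gradient ops below consume. A minimal NumPy sketch of the clamping step, illustrative only:

import numpy as np

x = np.random.uniform(-10, 10, (32, 120)).astype(np.float32)
clamp_min = 0.7 * float(x.min())
clamp_max = 0.7 * float(x.max())
clamp_min_mask = x >= clamp_min        # True where the element was not clipped from below
clamp_max_mask = x <= clamp_max        # True where the element was not clipped from above
x_clamped = np.clip(x, clamp_min, clamp_max)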
+ class ActsULQInputGrad(PrimitiveWithInfer):
1654
+ """
1655
+ The ActsULQInputGrad(grad of ActsULQ).
1656
+
1657
+ Inputs:
1658
+ - **y_grad** (Tensor) - A Tensor of grad. With float16 or float32 data type.
1659
+
1660
+ Outputs:
1661
+ - **x_grad** (Tensor) - A tensor of data grad with the same type as `y_grad`.
1662
+ """
1663
+ @prim_attr_register
1664
+ def __init__(self):
1665
+ pass
1666
+
1667
+ def infer_shape(self, y_grad_shape, clamp_min_mask_shape, clamp_max_mask_shape):
1668
+ return y_grad_shape
1669
+
1670
+ def infer_dtype(self, y_grad_type, clamp_min_mask_type, clamp_max_mask_type):
1671
+ valid_types = [mstype.float32, mstype.float16]
1672
+ validator.check_tensor_dtype_valid("y_grad", y_grad_type, valid_types, self.name)
1673
+ return y_grad_type
1674
+
1675
+
1676
+ class ActULQClampMinGrad(PrimitiveWithInfer):
1677
+ """
1678
+ The ActULQClampMinGrad(Activation Universal Linear Quantization on Clamp Minimum Gradient)
1679
+
1680
+ Inputs:
1681
+ - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
1682
+ - **clamp_min_mask** - A tensor of masks; only the int8 type is supported.
1683
+ - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
1684
+
1685
+ Outputs:
1686
+ - **clamp_min_grad** - A tensor of clamp minimum gradient, with the same type as "y_grad".
1687
+ The length of tensor is 1.
1688
+
1689
+ Examples:
1690
+ >>> data_type = np.float32
1691
+ >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1692
+ >>> clamp_min_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
1693
+ >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1694
+ >>> act_ulq_clamp_min_grad = Q.ActULQClampMinGrad()
1695
+ >>> clamp_min_grad = act_ulq_clamp_min_grad(Tensor(y_grad), Tensor(clamp_min_mask, mindspore.bool_),
1696
+ ... Tensor(x_clamped_loss))
1697
+ """
1698
+ @prim_attr_register
1699
+ def __init__(self):
1700
+ pass
1701
+
1702
+ def infer_shape(self, input_x, input_y, input_z):
1703
+ input_x_len = len(input_x)
1704
+ output_shape = []
1705
+ for _ in range(input_x_len):
1706
+ output_shape.append(1)
1707
+ return tuple(output_shape)
1708
+
1709
+ def infer_dtype(self, input_x, input_y, input_z):
1710
+ return mstype.float32
1711
+
1712
+
1713
+ class ActULQClampMaxGrad(PrimitiveWithInfer):
1714
+ """
1715
+ The ActULQClampMaxGrad(Activation Universal Linear Quantization on Clamp Maximum Gradient)
1716
+
1717
+ Inputs:
1718
+ - **y_grad** (Tensor) - A tensor of gradient, with float16 or float32 type.
1719
+ - **clamp_max_mask** - A tensor of masks; only the int8 type is supported.
1720
+ - **x_clamped_loss** - A tensor of loss, with the same type as "y_grad".
1721
+
1722
+ Outputs:
1723
+ - **clamp_max_grad** - A tensor of clamp maximum gradient, with the same type as "y_grad".
1724
+ The length of tensor is 1.
1725
+
1726
+ Examples:
1727
+ >>> data_type = np.float32
1728
+ >>> y_grad = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1729
+ >>> clamp_max_mask = np.where(np.random.rand(32, 120) >= 0.5, 1, 0)
1730
+ >>> x_clamped_loss = np.random.uniform(-10, 10, (32, 120)).astype(data_type)
1731
+ >>> act_ulq_clamp_max_grad = Q.ActULQClampMaxGrad()
1732
+ >>> clamp_max_grad = act_ulq_clamp_max_grad(Tensor(y_grad), Tensor(clamp_max_mask, mindspore.bool_),
1733
+ ... Tensor(x_clamped_loss))
1734
+ """
1735
+ @prim_attr_register
1736
+ def __init__(self):
1737
+ pass
1738
+
1739
+ def infer_shape(self, input_x, input_y, input_z):
1740
+ input_x_len = len(input_x)
1741
+ output_shape = []
1742
+ for _ in range(input_x_len):
1743
+ output_shape.append(1)
1744
+ return tuple(output_shape)
1745
+
1746
+ def infer_dtype(self, input_x, input_y, input_z):
1747
+ return mstype.float32
1748
+
1749
+
1750
+ class WtsARQ(PrimitiveWithInfer):
1751
+ """
1752
+ The WtsARQ(Weights Adaptive Range Quantization).
1753
+
1754
+ Args:
1755
+ num_bits (int): The number of bits used for quantization.
1757
+ offset_flag (bool): Whether to use an offset for quantization.
1757
+
1758
+ Inputs:
1759
+ - **w** (Tensor) - A Tensor of weights. With float16 or float32 data type.
1760
+
1761
+ Outputs:
1762
+ - **scale** (Tensor) - A tensor of optimal scale, has the same type as `w`.
1763
+ - **offset** (Tensor) - A tensor of optimal offset, has the same type as `w`.
1764
+ - If axis is [],
1765
+ the shape of scale and offset is :math:`(1, )`.
1766
+ - If axis is [0],
1767
+ the shape of scale and offset is :math:`(w_1, )`.
1768
+ - If axis is [1],
1769
+ the shape of scale and offset is :math:`(w_2, )`.
1770
+ - **y** (Tensor) - A tensor of fakequant weights, has the same type and shape as `w`.
1771
+
1772
+ Examples:
1773
+ >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
1774
+ >>> wts_arq = Q.WtsARQ(num_bits=8, offset_flag=False)
1775
+ >>> scale, offset, y = wts_arq(data)
1776
+ """
1777
+ @prim_attr_register
1778
+ def __init__(self, num_bits, offset_flag):
1779
+ validator.check_value_type("num_bits", num_bits, [int], self.name)
1780
+ validator.check_int(num_bits, 8, validator.EQ, "value of num_bits", self.name)
1781
+ validator.check_value_type("offset_flag", offset_flag, [bool], self.name)
1782
+
1783
+ def infer_shape(self, w_shape, w_min_shape, w_max_shape):
1784
+ validator.check_int(len(w_min_shape), len(w_shape), validator.EQ, "dims of w_min", self.name)
1785
+ validator.check_int(len(w_max_shape), len(w_shape), validator.EQ, "dims of w_max", self.name)
1786
+ return w_shape
1787
+
1788
+ def infer_dtype(self, w_dtype, w_min_dtype, w_max_dtype):
1789
+ valid_types = [mstype.float32, mstype.float16]
1790
+ validator.check_tensor_dtype_valid("w", w_dtype, valid_types, self.name)
1791
+ validator.check_tensor_dtype_valid("w_min", w_min_dtype, valid_types, self.name)
1792
+ validator.check_tensor_dtype_valid("w_max", w_max_dtype, valid_types, self.name)
1793
+ return w_dtype
1794
+
1795
+
1796
+ class IFMR(Primitive):
1797
+ """
1798
+ The IFMR (Input Feature Map Reconstruction).
1799
+
1800
+ Args:
1801
+ min_percentile (float): Min init percentile. Default: 0.999999.
1802
+ max_percentile (float): Max init percentile. Default: 0.999999.
1803
+ search_range (Union[list(float), tuple(float)]): Range of searching. Default: [0.7, 1.3].
1804
+ search_step (float): Step size of searching. Default: 0.01.
1805
+ with_offset (bool): Whether using offset. Default: ``True``.
1806
+
1807
+ Inputs:
1808
+ - **data** (Tensor) - A Tensor of feature map. With float16 or float32 data type.
1809
+ - **data_min** (Tensor) - A Tensor of min value of feature map, the shape is :math:`(1)`.
1810
+ With float16 or float32 data type.
1811
+ - **data_max** (Tensor) - A Tensor of max value of feature map, the shape is :math:`(1)`.
1812
+ With float16 or float32 data type.
1813
+ - **cumsum** (Tensor) - A `1-D` Tensor of cumsum bin of data. With int32 data type.
1814
+
1815
+ Outputs:
1816
+ - **scale** (Tensor) - A tensor of optimal scale, the shape is :math:`(1)`. Data dtype is float32.
1817
+ - **offset** (Tensor) - A tensor of optimal offset, the shape is :math:`(1)`. Data dtype is float32.
1818
+
1819
+ Examples:
1820
+ >>> data = Tensor(np.random.rand(1, 3, 6, 4).astype(np.float32))
1821
+ >>> data_min = Tensor([0.1], mindspore.float32)
1822
+ >>> data_max = Tensor([0.5], mindspore.float32)
1823
+ >>> cumsum = Tensor(np.random.rand(4).astype(np.int32))
1824
+ >>> ifmr = Q.IFMR(min_percentile=0.2, max_percentile=0.9, search_range=(1.0, 2.0),
1825
+ ... search_step=1.0, with_offset=False)
1826
+ >>> output = ifmr(data, data_min, data_max, cumsum)
1827
+ >>> print(output)
1828
+ (Tensor(shape=[1], dtype=Float32, value= [7.87401572e-03]),
1829
+ Tensor(shape=[1], dtype=Float32, value= [0.00000000e+00]))
1830
+ """
1831
+
1832
+ @prim_attr_register
1833
+ def __init__(self, min_percentile=0.999999, max_percentile=0.999999, search_range=(0.7, 1.3), search_step=0.01,
1834
+ with_offset=True):
1835
+ self.init_prim_io_names(
1836
+ inputs=['data', 'data_min', 'data_max', 'cumsum'], outputs=['scale', 'offset'])
1837
+ validator.check_value_type("min_percentile", min_percentile, [float], self.name)
1838
+ validator.check_value_type("max_percentile", max_percentile, [float], self.name)
1839
+ validator.check_value_type("search_range", search_range, [list, tuple], self.name)
1840
+ for item in search_range:
1841
+ validator.check_positive_float(item, "item of search_range", self.name)
1842
+ validator.check('search_range[1]', search_range[1], 'search_range[0]', search_range[0], validator.GE, self.name)
1843
+ validator.check_value_type("search_step", search_step, [float], self.name)
1844
+ validator.check_value_type("offset_flag", with_offset, [bool], self.name)