mindspore-2.4.0-cp310-cp310-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mindspore might be problematic. Click here for more details.

Files changed (1387):
  1. mindspore/.commit_id +1 -0
  2. mindspore/__init__.py +53 -0
  3. mindspore/_c_dataengine.cpython-310-darwin.so +0 -0
  4. mindspore/_c_expression.cpython-310-darwin.so +0 -0
  5. mindspore/_c_mindrecord.cpython-310-darwin.so +0 -0
  6. mindspore/_check_jit_forbidden_api.py +106 -0
  7. mindspore/_checkparam.py +1419 -0
  8. mindspore/_extends/__init__.py +23 -0
  9. mindspore/_extends/builtin_operations.py +224 -0
  10. mindspore/_extends/graph_kernel/__init__.py +17 -0
  11. mindspore/_extends/graph_kernel/model/__init__.py +19 -0
  12. mindspore/_extends/graph_kernel/model/graph_parallel.py +311 -0
  13. mindspore/_extends/graph_kernel/model/graph_split.py +1348 -0
  14. mindspore/_extends/graph_kernel/model/model.py +553 -0
  15. mindspore/_extends/graph_kernel/model/model_builder.py +216 -0
  16. mindspore/_extends/graph_kernel/parallel_estimate.py +60 -0
  17. mindspore/_extends/graph_kernel/splitter.py +140 -0
  18. mindspore/_extends/graph_kernel/utils.py +28 -0
  19. mindspore/_extends/parallel_compile/__init__.py +19 -0
  20. mindspore/_extends/parallel_compile/akg_compiler/__init__.py +19 -0
  21. mindspore/_extends/parallel_compile/akg_compiler/akg_process.py +269 -0
  22. mindspore/_extends/parallel_compile/akg_compiler/build_tbe_kernel.py +529 -0
  23. mindspore/_extends/parallel_compile/akg_compiler/compiler.py +56 -0
  24. mindspore/_extends/parallel_compile/akg_compiler/gen_custom_op_files.py +96 -0
  25. mindspore/_extends/parallel_compile/akg_compiler/get_file_path.py +36 -0
  26. mindspore/_extends/parallel_compile/akg_compiler/tbe_topi.py +556 -0
  27. mindspore/_extends/parallel_compile/akg_compiler/util.py +159 -0
  28. mindspore/_extends/parse/__init__.py +49 -0
  29. mindspore/_extends/parse/compile_config.py +299 -0
  30. mindspore/_extends/parse/namespace.py +136 -0
  31. mindspore/_extends/parse/parser.py +1448 -0
  32. mindspore/_extends/parse/resources.py +213 -0
  33. mindspore/_extends/parse/standard_method.py +4475 -0
  34. mindspore/_extends/parse/trope.py +97 -0
  35. mindspore/_extends/pijit/__init__.py +23 -0
  36. mindspore/_extends/pijit/pijit_func_white_list.py +669 -0
  37. mindspore/_extends/remote/__init__.py +19 -0
  38. mindspore/_extends/remote/kernel_build_server.py +199 -0
  39. mindspore/_extends/remote/kernel_build_server_akg.py +55 -0
  40. mindspore/_extends/remote/kernel_build_server_akg_v2.py +55 -0
  41. mindspore/_extends/remote/kernel_build_server_ascend.py +75 -0
  42. mindspore/_extends/utils.py +68 -0
  43. mindspore/_install_custom.py +43 -0
  44. mindspore/_profiler.py +30 -0
  45. mindspore/amp.py +433 -0
  46. mindspore/boost/__init__.py +42 -0
  47. mindspore/boost/adasum.py +319 -0
  48. mindspore/boost/base.py +535 -0
  49. mindspore/boost/boost.py +400 -0
  50. mindspore/boost/boost_cell_wrapper.py +790 -0
  51. mindspore/boost/dim_reduce.py +323 -0
  52. mindspore/boost/grad_accumulation.py +79 -0
  53. mindspore/boost/grad_freeze.py +382 -0
  54. mindspore/boost/group_loss_scale_manager.py +166 -0
  55. mindspore/boost/less_batch_normalization.py +174 -0
  56. mindspore/common/__init__.py +86 -0
  57. mindspore/common/_auto_dynamic.py +68 -0
  58. mindspore/common/_decorator.py +50 -0
  59. mindspore/common/_jit_fallback_utils.py +110 -0
  60. mindspore/common/_monad.py +25 -0
  61. mindspore/common/_pijit_context.py +190 -0
  62. mindspore/common/_register_for_adapter.py +74 -0
  63. mindspore/common/_register_for_recompute.py +48 -0
  64. mindspore/common/_register_for_tensor.py +46 -0
  65. mindspore/common/_stub_tensor.py +210 -0
  66. mindspore/common/_tensor_overload.py +139 -0
  67. mindspore/common/_utils.py +122 -0
  68. mindspore/common/api.py +2064 -0
  69. mindspore/common/auto_dynamic_shape.py +507 -0
  70. mindspore/common/dtype.py +422 -0
  71. mindspore/common/dump.py +130 -0
  72. mindspore/common/file_system.py +48 -0
  73. mindspore/common/generator.py +254 -0
  74. mindspore/common/hook_handle.py +143 -0
  75. mindspore/common/initializer.py +880 -0
  76. mindspore/common/jit_config.py +98 -0
  77. mindspore/common/lazy_inline.py +240 -0
  78. mindspore/common/mindir_util.py +111 -0
  79. mindspore/common/mutable.py +234 -0
  80. mindspore/common/no_inline.py +54 -0
  81. mindspore/common/np_dtype.py +25 -0
  82. mindspore/common/parameter.py +1081 -0
  83. mindspore/common/recompute.py +292 -0
  84. mindspore/common/seed.py +260 -0
  85. mindspore/common/sparse_tensor.py +1175 -0
  86. mindspore/common/symbol.py +122 -0
  87. mindspore/common/tensor.py +5039 -0
  88. mindspore/communication/__init__.py +37 -0
  89. mindspore/communication/_comm_helper.py +501 -0
  90. mindspore/communication/_hccl_management.py +297 -0
  91. mindspore/communication/comm_func.py +1395 -0
  92. mindspore/communication/management.py +673 -0
  93. mindspore/config/op_info.config +533 -0
  94. mindspore/context.py +2077 -0
  95. mindspore/dataset/__init__.py +90 -0
  96. mindspore/dataset/audio/__init__.py +61 -0
  97. mindspore/dataset/audio/transforms.py +3690 -0
  98. mindspore/dataset/audio/utils.py +386 -0
  99. mindspore/dataset/audio/validators.py +1172 -0
  100. mindspore/dataset/callback/__init__.py +20 -0
  101. mindspore/dataset/callback/ds_callback.py +368 -0
  102. mindspore/dataset/callback/validators.py +32 -0
  103. mindspore/dataset/core/__init__.py +13 -0
  104. mindspore/dataset/core/config.py +1095 -0
  105. mindspore/dataset/core/datatypes.py +101 -0
  106. mindspore/dataset/core/py_util_helpers.py +65 -0
  107. mindspore/dataset/core/validator_helpers.py +781 -0
  108. mindspore/dataset/debug/__init__.py +21 -0
  109. mindspore/dataset/debug/debug_hook.py +97 -0
  110. mindspore/dataset/debug/pre_defined_hook.py +67 -0
  111. mindspore/dataset/engine/__init__.py +124 -0
  112. mindspore/dataset/engine/cache_admin.py +47 -0
  113. mindspore/dataset/engine/cache_client.py +129 -0
  114. mindspore/dataset/engine/datasets.py +4582 -0
  115. mindspore/dataset/engine/datasets_audio.py +911 -0
  116. mindspore/dataset/engine/datasets_standard_format.py +543 -0
  117. mindspore/dataset/engine/datasets_text.py +2161 -0
  118. mindspore/dataset/engine/datasets_user_defined.py +1184 -0
  119. mindspore/dataset/engine/datasets_vision.py +4816 -0
  120. mindspore/dataset/engine/iterators.py +371 -0
  121. mindspore/dataset/engine/obs/__init__.py +23 -0
  122. mindspore/dataset/engine/obs/config_loader.py +68 -0
  123. mindspore/dataset/engine/obs/obs_mindrecord_dataset.py +508 -0
  124. mindspore/dataset/engine/obs/util.py +482 -0
  125. mindspore/dataset/engine/offload.py +596 -0
  126. mindspore/dataset/engine/queue.py +304 -0
  127. mindspore/dataset/engine/samplers.py +895 -0
  128. mindspore/dataset/engine/serializer_deserializer.py +159 -0
  129. mindspore/dataset/engine/validators.py +2895 -0
  130. mindspore/dataset/text/__init__.py +51 -0
  131. mindspore/dataset/text/transforms.py +1703 -0
  132. mindspore/dataset/text/utils.py +715 -0
  133. mindspore/dataset/text/validators.py +642 -0
  134. mindspore/dataset/transforms/__init__.py +45 -0
  135. mindspore/dataset/transforms/c_transforms.py +638 -0
  136. mindspore/dataset/transforms/py_transforms.py +393 -0
  137. mindspore/dataset/transforms/py_transforms_util.py +255 -0
  138. mindspore/dataset/transforms/transforms.py +1260 -0
  139. mindspore/dataset/transforms/validators.py +410 -0
  140. mindspore/dataset/utils/__init__.py +19 -0
  141. mindspore/dataset/utils/browse_dataset.py +190 -0
  142. mindspore/dataset/utils/line_reader.py +126 -0
  143. mindspore/dataset/vision/__init__.py +65 -0
  144. mindspore/dataset/vision/c_transforms.py +2641 -0
  145. mindspore/dataset/vision/py_transforms.py +2120 -0
  146. mindspore/dataset/vision/py_transforms_util.py +1660 -0
  147. mindspore/dataset/vision/transforms.py +7295 -0
  148. mindspore/dataset/vision/utils.py +863 -0
  149. mindspore/dataset/vision/validators.py +1483 -0
  150. mindspore/default_config.py +2 -0
  151. mindspore/experimental/__init__.py +20 -0
  152. mindspore/experimental/es/__init__.py +22 -0
  153. mindspore/experimental/es/embedding_service.py +883 -0
  154. mindspore/experimental/es/embedding_service_layer.py +581 -0
  155. mindspore/experimental/llm_boost/__init__.py +21 -0
  156. mindspore/experimental/llm_boost/atb/__init__.py +23 -0
  157. mindspore/experimental/llm_boost/atb/boost_base.py +211 -0
  158. mindspore/experimental/llm_boost/atb/llama_boost.py +115 -0
  159. mindspore/experimental/llm_boost/atb/qwen_boost.py +101 -0
  160. mindspore/experimental/llm_boost/register.py +129 -0
  161. mindspore/experimental/llm_boost/utils.py +31 -0
  162. mindspore/experimental/map_parameter.py +309 -0
  163. mindspore/experimental/optim/__init__.py +40 -0
  164. mindspore/experimental/optim/adadelta.py +161 -0
  165. mindspore/experimental/optim/adagrad.py +168 -0
  166. mindspore/experimental/optim/adam.py +193 -0
  167. mindspore/experimental/optim/adamax.py +170 -0
  168. mindspore/experimental/optim/adamw.py +290 -0
  169. mindspore/experimental/optim/asgd.py +153 -0
  170. mindspore/experimental/optim/lr_scheduler.py +1371 -0
  171. mindspore/experimental/optim/nadam.py +157 -0
  172. mindspore/experimental/optim/optimizer.py +262 -0
  173. mindspore/experimental/optim/radam.py +194 -0
  174. mindspore/experimental/optim/rmsprop.py +154 -0
  175. mindspore/experimental/optim/rprop.py +164 -0
  176. mindspore/experimental/optim/sgd.py +156 -0
  177. mindspore/hal/__init__.py +40 -0
  178. mindspore/hal/_ascend.py +57 -0
  179. mindspore/hal/_base.py +57 -0
  180. mindspore/hal/_cpu.py +56 -0
  181. mindspore/hal/_gpu.py +57 -0
  182. mindspore/hal/contiguous_tensors_handle.py +175 -0
  183. mindspore/hal/device.py +356 -0
  184. mindspore/hal/event.py +179 -0
  185. mindspore/hal/memory.py +326 -0
  186. mindspore/hal/stream.py +357 -0
  187. mindspore/include/OWNERS +7 -0
  188. mindspore/include/api/allocator.h +97 -0
  189. mindspore/include/api/callback/callback.h +93 -0
  190. mindspore/include/api/callback/ckpt_saver.h +41 -0
  191. mindspore/include/api/callback/loss_monitor.h +33 -0
  192. mindspore/include/api/callback/lr_scheduler.h +51 -0
  193. mindspore/include/api/callback/time_monitor.h +34 -0
  194. mindspore/include/api/callback/train_accuracy.h +37 -0
  195. mindspore/include/api/cell.h +90 -0
  196. mindspore/include/api/cfg.h +82 -0
  197. mindspore/include/api/context.h +602 -0
  198. mindspore/include/api/data_type.h +47 -0
  199. mindspore/include/api/delegate.h +178 -0
  200. mindspore/include/api/delegate_api.h +75 -0
  201. mindspore/include/api/dual_abi_helper.h +208 -0
  202. mindspore/include/api/format.h +28 -0
  203. mindspore/include/api/graph.h +46 -0
  204. mindspore/include/api/kernel.h +58 -0
  205. mindspore/include/api/kernel_api.h +168 -0
  206. mindspore/include/api/metrics/accuracy.h +36 -0
  207. mindspore/include/api/metrics/metrics.h +41 -0
  208. mindspore/include/api/model.h +438 -0
  209. mindspore/include/api/model_group.h +91 -0
  210. mindspore/include/api/model_parallel_runner.h +168 -0
  211. mindspore/include/api/serialization.h +185 -0
  212. mindspore/include/api/status.h +192 -0
  213. mindspore/include/api/types.h +431 -0
  214. mindspore/include/api/visible.h +41 -0
  215. mindspore/include/c_api/context_c.h +179 -0
  216. mindspore/include/c_api/data_type_c.h +52 -0
  217. mindspore/include/c_api/format_c.h +46 -0
  218. mindspore/include/c_api/model_c.h +347 -0
  219. mindspore/include/c_api/status_c.h +79 -0
  220. mindspore/include/c_api/tensor_c.h +146 -0
  221. mindspore/include/c_api/types_c.h +67 -0
  222. mindspore/include/dataset/config.h +163 -0
  223. mindspore/include/dataset/constants.h +363 -0
  224. mindspore/include/dataset/execute.h +196 -0
  225. mindspore/include/dataset/text.h +1092 -0
  226. mindspore/include/dataset/transforms.h +638 -0
  227. mindspore/include/dataset/vision.h +2129 -0
  228. mindspore/include/dataset/vision_ascend.h +206 -0
  229. mindspore/include/dataset/vision_lite.h +625 -0
  230. mindspore/lib/libavcodec.59.dylib +0 -0
  231. mindspore/lib/libavdevice.59.dylib +0 -0
  232. mindspore/lib/libavfilter.8.dylib +0 -0
  233. mindspore/lib/libavformat.59.dylib +0 -0
  234. mindspore/lib/libavutil.57.dylib +0 -0
  235. mindspore/lib/libdnnl.2.dylib +0 -0
  236. mindspore/lib/libicudata.69.dylib +0 -0
  237. mindspore/lib/libicui18n.69.dylib +0 -0
  238. mindspore/lib/libicuuc.69.dylib +0 -0
  239. mindspore/lib/libmindspore_address_sorting.15.dylib +0 -0
  240. mindspore/lib/libmindspore_backend.dylib +0 -0
  241. mindspore/lib/libmindspore_common.dylib +0 -0
  242. mindspore/lib/libmindspore_core.dylib +0 -0
  243. mindspore/lib/libmindspore_glog.0.dylib +0 -0
  244. mindspore/lib/libmindspore_gpr.15.dylib +0 -0
  245. mindspore/lib/libmindspore_grpc++.1.dylib +0 -0
  246. mindspore/lib/libmindspore_grpc.15.dylib +0 -0
  247. mindspore/lib/libmindspore_np_dtype.dylib +0 -0
  248. mindspore/lib/libmindspore_ops.dylib +0 -0
  249. mindspore/lib/libmindspore_upb.15.dylib +0 -0
  250. mindspore/lib/libnnacl.dylib +0 -0
  251. mindspore/lib/libopencv_core.4.5.dylib +0 -0
  252. mindspore/lib/libopencv_imgcodecs.4.5.dylib +0 -0
  253. mindspore/lib/libopencv_imgproc.4.5.dylib +0 -0
  254. mindspore/lib/libps_cache.dylib +0 -0
  255. mindspore/lib/libswresample.4.dylib +0 -0
  256. mindspore/lib/libswscale.6.dylib +0 -0
  257. mindspore/lib/libtinyxml2.8.dylib +0 -0
  258. mindspore/log.py +633 -0
  259. mindspore/mindrecord/__init__.py +43 -0
  260. mindspore/mindrecord/common/__init__.py +17 -0
  261. mindspore/mindrecord/common/constant.py +20 -0
  262. mindspore/mindrecord/common/enums.py +44 -0
  263. mindspore/mindrecord/common/exceptions.py +311 -0
  264. mindspore/mindrecord/config.py +809 -0
  265. mindspore/mindrecord/filereader.py +174 -0
  266. mindspore/mindrecord/filewriter.py +722 -0
  267. mindspore/mindrecord/mindpage.py +210 -0
  268. mindspore/mindrecord/shardheader.py +141 -0
  269. mindspore/mindrecord/shardindexgenerator.py +74 -0
  270. mindspore/mindrecord/shardreader.py +117 -0
  271. mindspore/mindrecord/shardsegment.py +128 -0
  272. mindspore/mindrecord/shardutils.py +185 -0
  273. mindspore/mindrecord/shardwriter.py +237 -0
  274. mindspore/mindrecord/tools/__init__.py +17 -0
  275. mindspore/mindrecord/tools/cifar10.py +140 -0
  276. mindspore/mindrecord/tools/cifar100.py +153 -0
  277. mindspore/mindrecord/tools/cifar100_to_mr.py +185 -0
  278. mindspore/mindrecord/tools/cifar10_to_mr.py +177 -0
  279. mindspore/mindrecord/tools/csv_to_mr.py +200 -0
  280. mindspore/mindrecord/tools/imagenet_to_mr.py +206 -0
  281. mindspore/mindrecord/tools/mnist_to_mr.py +259 -0
  282. mindspore/mindrecord/tools/tfrecord_to_mr.py +360 -0
  283. mindspore/mint/__init__.py +1586 -0
  284. mindspore/mint/distributed/__init__.py +31 -0
  285. mindspore/mint/distributed/distributed.py +254 -0
  286. mindspore/mint/linalg/__init__.py +22 -0
  287. mindspore/mint/nn/__init__.py +757 -0
  288. mindspore/mint/nn/functional.py +679 -0
  289. mindspore/mint/nn/layer/__init__.py +39 -0
  290. mindspore/mint/nn/layer/activation.py +133 -0
  291. mindspore/mint/nn/layer/normalization.py +477 -0
  292. mindspore/mint/nn/layer/pooling.py +110 -0
  293. mindspore/mint/optim/__init__.py +24 -0
  294. mindspore/mint/optim/adamw.py +206 -0
  295. mindspore/mint/special/__init__.py +63 -0
  296. mindspore/multiprocessing/__init__.py +73 -0
  297. mindspore/nn/__init__.py +47 -0
  298. mindspore/nn/cell.py +2787 -0
  299. mindspore/nn/dynamic_lr.py +482 -0
  300. mindspore/nn/grad/__init__.py +21 -0
  301. mindspore/nn/grad/cell_grad.py +196 -0
  302. mindspore/nn/layer/__init__.py +63 -0
  303. mindspore/nn/layer/activation.py +1822 -0
  304. mindspore/nn/layer/basic.py +1629 -0
  305. mindspore/nn/layer/channel_shuffle.py +90 -0
  306. mindspore/nn/layer/combined.py +248 -0
  307. mindspore/nn/layer/container.py +734 -0
  308. mindspore/nn/layer/conv.py +1505 -0
  309. mindspore/nn/layer/dense.py +204 -0
  310. mindspore/nn/layer/embedding.py +869 -0
  311. mindspore/nn/layer/image.py +661 -0
  312. mindspore/nn/layer/math.py +1069 -0
  313. mindspore/nn/layer/normalization.py +1273 -0
  314. mindspore/nn/layer/padding.py +880 -0
  315. mindspore/nn/layer/pooling.py +2302 -0
  316. mindspore/nn/layer/rnn_cells.py +388 -0
  317. mindspore/nn/layer/rnns.py +849 -0
  318. mindspore/nn/layer/thor_layer.py +963 -0
  319. mindspore/nn/layer/timedistributed.py +155 -0
  320. mindspore/nn/layer/transformer.py +823 -0
  321. mindspore/nn/learning_rate_schedule.py +512 -0
  322. mindspore/nn/loss/__init__.py +36 -0
  323. mindspore/nn/loss/loss.py +2924 -0
  324. mindspore/nn/metrics.py +53 -0
  325. mindspore/nn/optim/__init__.py +45 -0
  326. mindspore/nn/optim/_dist_optimizer_registry.py +111 -0
  327. mindspore/nn/optim/ada_grad.py +217 -0
  328. mindspore/nn/optim/adadelta.py +206 -0
  329. mindspore/nn/optim/adafactor.py +448 -0
  330. mindspore/nn/optim/adam.py +1297 -0
  331. mindspore/nn/optim/adamax.py +220 -0
  332. mindspore/nn/optim/adasum.py +548 -0
  333. mindspore/nn/optim/asgd.py +216 -0
  334. mindspore/nn/optim/ftrl.py +401 -0
  335. mindspore/nn/optim/lamb.py +296 -0
  336. mindspore/nn/optim/lars.py +202 -0
  337. mindspore/nn/optim/lazyadam.py +533 -0
  338. mindspore/nn/optim/momentum.py +239 -0
  339. mindspore/nn/optim/optimizer.py +1034 -0
  340. mindspore/nn/optim/proximal_ada_grad.py +242 -0
  341. mindspore/nn/optim/rmsprop.py +264 -0
  342. mindspore/nn/optim/rprop.py +251 -0
  343. mindspore/nn/optim/sgd.py +237 -0
  344. mindspore/nn/optim/tft_wrapper.py +127 -0
  345. mindspore/nn/optim/thor.py +1310 -0
  346. mindspore/nn/probability/__init__.py +22 -0
  347. mindspore/nn/probability/bijector/__init__.py +35 -0
  348. mindspore/nn/probability/bijector/bijector.py +337 -0
  349. mindspore/nn/probability/bijector/exp.py +65 -0
  350. mindspore/nn/probability/bijector/gumbel_cdf.py +144 -0
  351. mindspore/nn/probability/bijector/invert.py +126 -0
  352. mindspore/nn/probability/bijector/power_transform.py +196 -0
  353. mindspore/nn/probability/bijector/scalar_affine.py +167 -0
  354. mindspore/nn/probability/bijector/softplus.py +189 -0
  355. mindspore/nn/probability/bnn_layers/__init__.py +29 -0
  356. mindspore/nn/probability/bnn_layers/_util.py +46 -0
  357. mindspore/nn/probability/bnn_layers/bnn_cell_wrapper.py +112 -0
  358. mindspore/nn/probability/bnn_layers/conv_variational.py +267 -0
  359. mindspore/nn/probability/bnn_layers/dense_variational.py +302 -0
  360. mindspore/nn/probability/bnn_layers/layer_distribution.py +123 -0
  361. mindspore/nn/probability/distribution/__init__.py +56 -0
  362. mindspore/nn/probability/distribution/_utils/__init__.py +34 -0
  363. mindspore/nn/probability/distribution/_utils/custom_ops.py +96 -0
  364. mindspore/nn/probability/distribution/_utils/utils.py +362 -0
  365. mindspore/nn/probability/distribution/bernoulli.py +334 -0
  366. mindspore/nn/probability/distribution/beta.py +391 -0
  367. mindspore/nn/probability/distribution/categorical.py +435 -0
  368. mindspore/nn/probability/distribution/cauchy.py +383 -0
  369. mindspore/nn/probability/distribution/distribution.py +827 -0
  370. mindspore/nn/probability/distribution/exponential.py +350 -0
  371. mindspore/nn/probability/distribution/gamma.py +391 -0
  372. mindspore/nn/probability/distribution/geometric.py +335 -0
  373. mindspore/nn/probability/distribution/gumbel.py +257 -0
  374. mindspore/nn/probability/distribution/half_normal.py +133 -0
  375. mindspore/nn/probability/distribution/laplace.py +128 -0
  376. mindspore/nn/probability/distribution/log_normal.py +272 -0
  377. mindspore/nn/probability/distribution/logistic.py +379 -0
  378. mindspore/nn/probability/distribution/normal.py +336 -0
  379. mindspore/nn/probability/distribution/poisson.py +288 -0
  380. mindspore/nn/probability/distribution/student_t.py +149 -0
  381. mindspore/nn/probability/distribution/transformed_distribution.py +235 -0
  382. mindspore/nn/probability/distribution/uniform.py +375 -0
  383. mindspore/nn/reinforcement/__init__.py +24 -0
  384. mindspore/nn/reinforcement/_batch_read_write.py +142 -0
  385. mindspore/nn/reinforcement/_tensors_queue.py +152 -0
  386. mindspore/nn/reinforcement/tensor_array.py +145 -0
  387. mindspore/nn/sparse/__init__.py +23 -0
  388. mindspore/nn/sparse/sparse.py +147 -0
  389. mindspore/nn/wrap/__init__.py +49 -0
  390. mindspore/nn/wrap/cell_wrapper.py +968 -0
  391. mindspore/nn/wrap/grad_reducer.py +608 -0
  392. mindspore/nn/wrap/loss_scale.py +694 -0
  393. mindspore/numpy/__init__.py +121 -0
  394. mindspore/numpy/array_creations.py +2731 -0
  395. mindspore/numpy/array_ops.py +2629 -0
  396. mindspore/numpy/dtypes.py +185 -0
  397. mindspore/numpy/fft.py +966 -0
  398. mindspore/numpy/logic_ops.py +936 -0
  399. mindspore/numpy/math_ops.py +5911 -0
  400. mindspore/numpy/utils.py +214 -0
  401. mindspore/numpy/utils_const.py +565 -0
  402. mindspore/ops/__init__.py +56 -0
  403. mindspore/ops/_constants.py +30 -0
  404. mindspore/ops/_grad_experimental/__init__.py +31 -0
  405. mindspore/ops/_grad_experimental/grad_array_ops.py +830 -0
  406. mindspore/ops/_grad_experimental/grad_base.py +143 -0
  407. mindspore/ops/_grad_experimental/grad_comm_ops.py +714 -0
  408. mindspore/ops/_grad_experimental/grad_debug_ops.py +31 -0
  409. mindspore/ops/_grad_experimental/grad_implementations.py +203 -0
  410. mindspore/ops/_grad_experimental/grad_inner_ops.py +79 -0
  411. mindspore/ops/_grad_experimental/grad_math_ops.py +802 -0
  412. mindspore/ops/_grad_experimental/grad_nn_ops.py +231 -0
  413. mindspore/ops/_grad_experimental/grad_quant_ops.py +238 -0
  414. mindspore/ops/_grad_experimental/grad_sparse.py +342 -0
  415. mindspore/ops/_grad_experimental/grad_sparse_ops.py +399 -0
  416. mindspore/ops/_grad_experimental/taylor_rule.py +220 -0
  417. mindspore/ops/_op_impl/__init__.py +23 -0
  418. mindspore/ops/_op_impl/_custom_op/__init__.py +39 -0
  419. mindspore/ops/_op_impl/_custom_op/_basic.py +158 -0
  420. mindspore/ops/_op_impl/_custom_op/batch_matmul_impl.py +279 -0
  421. mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py +156 -0
  422. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py +109 -0
  423. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py +125 -0
  424. mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py +105 -0
  425. mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py +124 -0
  426. mindspore/ops/_op_impl/_custom_op/cholesky_trsm_impl.py +116 -0
  427. mindspore/ops/_op_impl/_custom_op/correction_mul.py +89 -0
  428. mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +196 -0
  429. mindspore/ops/_op_impl/_custom_op/dsd_back_impl.py +366 -0
  430. mindspore/ops/_op_impl/_custom_op/dsd_impl.py +162 -0
  431. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel.py +136 -0
  432. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad.py +206 -0
  433. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perchannel_grad_reduce.py +88 -0
  434. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer.py +128 -0
  435. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad.py +199 -0
  436. mindspore/ops/_op_impl/_custom_op/fake_learned_scale_quant_perlayer_grad_reduce.py +88 -0
  437. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +156 -0
  438. mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py +184 -0
  439. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py +143 -0
  440. mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py +169 -0
  441. mindspore/ops/_op_impl/_custom_op/fused_abs_max1_impl.py +548 -0
  442. mindspore/ops/_op_impl/_custom_op/img2col_impl.py +881 -0
  443. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_left_impl.py +278 -0
  444. mindspore/ops/_op_impl/_custom_op/matmul_cube_dense_right_impl.py +200 -0
  445. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_left_cast_impl.py +334 -0
  446. mindspore/ops/_op_impl/_custom_op/matmul_cube_fracz_right_mul_impl.py +255 -0
  447. mindspore/ops/_op_impl/_custom_op/matmul_cube_impl.py +222 -0
  448. mindspore/ops/_op_impl/_custom_op/matmul_dds_grad_impl.py +644 -0
  449. mindspore/ops/_op_impl/_custom_op/matmul_dds_impl.py +488 -0
  450. mindspore/ops/_op_impl/_custom_op/matrix_combine_impl.py +87 -0
  451. mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py +129 -0
  452. mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py +121 -0
  453. mindspore/ops/_op_impl/_custom_op/transpose02314_impl.py +352 -0
  454. mindspore/ops/_op_impl/aicpu/__init__.py +441 -0
  455. mindspore/ops/_op_impl/aicpu/abs.py +36 -0
  456. mindspore/ops/_op_impl/aicpu/acos.py +32 -0
  457. mindspore/ops/_op_impl/aicpu/acos_grad.py +33 -0
  458. mindspore/ops/_op_impl/aicpu/acosh.py +34 -0
  459. mindspore/ops/_op_impl/aicpu/acosh_grad.py +35 -0
  460. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d.py +34 -0
  461. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_2d_grad.py +34 -0
  462. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d.py +39 -0
  463. mindspore/ops/_op_impl/aicpu/adaptive_avg_pool_3d_grad.py +39 -0
  464. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d.py +37 -0
  465. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_2d_grad.py +37 -0
  466. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d.py +42 -0
  467. mindspore/ops/_op_impl/aicpu/adaptive_max_pool_3d_grad.py +152 -0
  468. mindspore/ops/_op_impl/aicpu/add.py +43 -0
  469. mindspore/ops/_op_impl/aicpu/add_n.py +41 -0
  470. mindspore/ops/_op_impl/aicpu/add_v2.py +40 -0
  471. mindspore/ops/_op_impl/aicpu/addcdiv.py +41 -0
  472. mindspore/ops/_op_impl/aicpu/addcmul.py +47 -0
  473. mindspore/ops/_op_impl/aicpu/adjust_contrastv2.py +32 -0
  474. mindspore/ops/_op_impl/aicpu/adjust_hue.py +31 -0
  475. mindspore/ops/_op_impl/aicpu/adjust_saturation.py +32 -0
  476. mindspore/ops/_op_impl/aicpu/affine_grid.py +33 -0
  477. mindspore/ops/_op_impl/aicpu/affine_grid_grad.py +35 -0
  478. mindspore/ops/_op_impl/aicpu/angle.py +31 -0
  479. mindspore/ops/_op_impl/aicpu/arg_max.py +75 -0
  480. mindspore/ops/_op_impl/aicpu/arg_min.py +75 -0
  481. mindspore/ops/_op_impl/aicpu/argmax_with_value.py +43 -0
  482. mindspore/ops/_op_impl/aicpu/argmin_with_value.py +43 -0
  483. mindspore/ops/_op_impl/aicpu/asin.py +32 -0
  484. mindspore/ops/_op_impl/aicpu/asin_grad.py +33 -0
  485. mindspore/ops/_op_impl/aicpu/asinh.py +34 -0
  486. mindspore/ops/_op_impl/aicpu/asinh_grad.py +35 -0
  487. mindspore/ops/_op_impl/aicpu/atanh.py +34 -0
  488. mindspore/ops/_op_impl/aicpu/avgpool_grad_v1.py +37 -0
  489. mindspore/ops/_op_impl/aicpu/avgpool_v1.py +36 -0
  490. mindspore/ops/_op_impl/aicpu/bartlett_window.py +36 -0
  491. mindspore/ops/_op_impl/aicpu/batch_matmul.py +43 -0
  492. mindspore/ops/_op_impl/aicpu/batch_norm_grad_grad.py +49 -0
  493. mindspore/ops/_op_impl/aicpu/bernoulli.py +48 -0
  494. mindspore/ops/_op_impl/aicpu/bessel_i0.py +31 -0
  495. mindspore/ops/_op_impl/aicpu/betainc.py +31 -0
  496. mindspore/ops/_op_impl/aicpu/bias_add.py +44 -0
  497. mindspore/ops/_op_impl/aicpu/bias_add_grad.py +42 -0
  498. mindspore/ops/_op_impl/aicpu/bincount.py +33 -0
  499. mindspore/ops/_op_impl/aicpu/blackman_window.py +36 -0
  500. mindspore/ops/_op_impl/aicpu/broadcast_to.py +58 -0
  501. mindspore/ops/_op_impl/aicpu/bucketize.py +34 -0
  502. mindspore/ops/_op_impl/aicpu/cache_swap_table.py +102 -0
  503. mindspore/ops/_op_impl/aicpu/cast.py +225 -0
  504. mindspore/ops/_op_impl/aicpu/cauchy.py +33 -0
  505. mindspore/ops/_op_impl/aicpu/channel_shuffle.py +40 -0
  506. mindspore/ops/_op_impl/aicpu/check_numerics.py +33 -0
  507. mindspore/ops/_op_impl/aicpu/cholesky.py +32 -0
  508. mindspore/ops/_op_impl/aicpu/cholesky_inverse.py +31 -0
  509. mindspore/ops/_op_impl/aicpu/cholesky_solve.py +33 -0
  510. mindspore/ops/_op_impl/aicpu/choleskygrad.py +32 -0
  511. mindspore/ops/_op_impl/aicpu/coalesce.py +37 -0
  512. mindspore/ops/_op_impl/aicpu/col2im.py +38 -0
  513. mindspore/ops/_op_impl/aicpu/combined_non_max_suppression.py +42 -0
  514. mindspore/ops/_op_impl/aicpu/compare_and_bitpack.py +37 -0
  515. mindspore/ops/_op_impl/aicpu/complex.py +32 -0
  516. mindspore/ops/_op_impl/aicpu/complex_abs.py +31 -0
  517. mindspore/ops/_op_impl/aicpu/compute_accidental_hits.py +44 -0
  518. mindspore/ops/_op_impl/aicpu/concat.py +57 -0
  519. mindspore/ops/_op_impl/aicpu/concat_offset.py +42 -0
  520. mindspore/ops/_op_impl/aicpu/concat_offset_v1.py +31 -0
  521. mindspore/ops/_op_impl/aicpu/conj.py +42 -0
  522. mindspore/ops/_op_impl/aicpu/conjugate_transpose.py +58 -0
  523. mindspore/ops/_op_impl/aicpu/cos.py +34 -0
  524. mindspore/ops/_op_impl/aicpu/cosh.py +34 -0
  525. mindspore/ops/_op_impl/aicpu/count_nonzero.py +43 -0
  526. mindspore/ops/_op_impl/aicpu/crop_and_resize.py +69 -0
  527. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_boxes.py +68 -0
  528. mindspore/ops/_op_impl/aicpu/crop_and_resize_grad_image.py +38 -0
  529. mindspore/ops/_op_impl/aicpu/cross.py +42 -0
  530. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_dense.py +48 -0
  531. mindspore/ops/_op_impl/aicpu/csr_sparse_matrix_to_sparse_tensor.py +51 -0
  532. mindspore/ops/_op_impl/aicpu/ctc_greedy_decoder.py +35 -0
  533. mindspore/ops/_op_impl/aicpu/ctc_loss_v2.py +43 -0
  534. mindspore/ops/_op_impl/aicpu/ctc_loss_v2_grad.py +45 -0
  535. mindspore/ops/_op_impl/aicpu/ctcloss.py +38 -0
  536. mindspore/ops/_op_impl/aicpu/cummax.py +41 -0
  537. mindspore/ops/_op_impl/aicpu/cumprod.py +58 -0
  538. mindspore/ops/_op_impl/aicpu/cumsum.py +58 -0
  539. mindspore/ops/_op_impl/aicpu/cumulative_logsumexp.py +36 -0
  540. mindspore/ops/_op_impl/aicpu/data_format_vec_permute.py +32 -0
  541. mindspore/ops/_op_impl/aicpu/deformable_offsets.py +38 -0
  542. mindspore/ops/_op_impl/aicpu/deformable_offsets_grad.py +43 -0
  543. mindspore/ops/_op_impl/aicpu/dense_to_csr_sparse_matrix.py +49 -0
  544. mindspore/ops/_op_impl/aicpu/dense_to_dense_set_operation.py +45 -0
  545. mindspore/ops/_op_impl/aicpu/dense_to_sparse_set_operation.py +48 -0
  546. mindspore/ops/_op_impl/aicpu/depth_to_space.py +44 -0
  547. mindspore/ops/_op_impl/aicpu/diag.py +36 -0
  548. mindspore/ops/_op_impl/aicpu/diag_part.py +36 -0
  549. mindspore/ops/_op_impl/aicpu/diagonal.py +35 -0
  550. mindspore/ops/_op_impl/aicpu/digamma.py +31 -0
  551. mindspore/ops/_op_impl/aicpu/div.py +41 -0
  552. mindspore/ops/_op_impl/aicpu/div_no_nan.py +35 -0
  553. mindspore/ops/_op_impl/aicpu/dropout2d.py +42 -0
  554. mindspore/ops/_op_impl/aicpu/dropout3d.py +42 -0
  555. mindspore/ops/_op_impl/aicpu/dropout_genmask.py +41 -0
  556. mindspore/ops/_op_impl/aicpu/dropout_genmask_v3.py +32 -0
  557. mindspore/ops/_op_impl/aicpu/dynamic_stitch.py +42 -0
  558. mindspore/ops/_op_impl/aicpu/edit_distance.py +56 -0
  559. mindspore/ops/_op_impl/aicpu/eig.py +35 -0
  560. mindspore/ops/_op_impl/aicpu/embedding_lookup.py +102 -0
  561. mindspore/ops/_op_impl/aicpu/end_of_sequence.py +30 -0
  562. mindspore/ops/_op_impl/aicpu/environ_create.py +28 -0
  563. mindspore/ops/_op_impl/aicpu/environ_destroy_all.py +28 -0
  564. mindspore/ops/_op_impl/aicpu/environ_get.py +41 -0
  565. mindspore/ops/_op_impl/aicpu/environ_set.py +40 -0
  566. mindspore/ops/_op_impl/aicpu/eps.py +32 -0
  567. mindspore/ops/_op_impl/aicpu/equal.py +41 -0
  568. mindspore/ops/_op_impl/aicpu/exp.py +37 -0
  569. mindspore/ops/_op_impl/aicpu/expand.py +45 -0
  570. mindspore/ops/_op_impl/aicpu/expand_dims.py +42 -0
  571. mindspore/ops/_op_impl/aicpu/expm1.py +34 -0
  572. mindspore/ops/_op_impl/aicpu/extract_glimpse.py +35 -0
  573. mindspore/ops/_op_impl/aicpu/eye.py +44 -0
  574. mindspore/ops/_op_impl/aicpu/fft_with_size.py +47 -0
  575. mindspore/ops/_op_impl/aicpu/fill_diagonal.py +39 -0
  576. mindspore/ops/_op_impl/aicpu/fill_v2.py +58 -0
  577. mindspore/ops/_op_impl/aicpu/flatten.py +43 -0
  578. mindspore/ops/_op_impl/aicpu/floor_div.py +38 -0
  579. mindspore/ops/_op_impl/aicpu/fmax.py +36 -0
  580. mindspore/ops/_op_impl/aicpu/fmin.py +37 -0
  581. mindspore/ops/_op_impl/aicpu/fractional_avg_pool.py +41 -0
  582. mindspore/ops/_op_impl/aicpu/fractional_avg_pool_grad.py +41 -0
  583. mindspore/ops/_op_impl/aicpu/fractional_max_pool.py +41 -0
  584. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_grad_with_fixed_ksize.py +43 -0
  585. mindspore/ops/_op_impl/aicpu/fractional_max_pool3d_with_fixed_ksize.py +65 -0
  586. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad.py +42 -0
  587. mindspore/ops/_op_impl/aicpu/fractional_max_pool_grad_with_fixed_ksize.py +42 -0
  588. mindspore/ops/_op_impl/aicpu/fractional_max_pool_with_fixed_ksize.py +49 -0
  589. mindspore/ops/_op_impl/aicpu/fse_decode.py +43 -0
  590. mindspore/ops/_op_impl/aicpu/fused_sparse_adam.py +46 -0
  591. mindspore/ops/_op_impl/aicpu/fused_sparse_ftrl.py +41 -0
  592. mindspore/ops/_op_impl/aicpu/fused_sparse_lazy_adam.py +46 -0
  593. mindspore/ops/_op_impl/aicpu/fused_sparse_proximal_adagrad.py +39 -0
  594. mindspore/ops/_op_impl/aicpu/gamma.py +38 -0
  595. mindspore/ops/_op_impl/aicpu/gather.py +46 -0
  596. mindspore/ops/_op_impl/aicpu/gather_d.py +79 -0
  597. mindspore/ops/_op_impl/aicpu/gather_d_grad_v2.py +79 -0
  598. mindspore/ops/_op_impl/aicpu/gather_grad.py +54 -0
  599. mindspore/ops/_op_impl/aicpu/gather_nd.py +56 -0
  600. mindspore/ops/_op_impl/aicpu/gcd.py +32 -0
  601. mindspore/ops/_op_impl/aicpu/generate_eod_mask.py +38 -0
  602. mindspore/ops/_op_impl/aicpu/geqrf.py +32 -0
  603. mindspore/ops/_op_impl/aicpu/get_next.py +39 -0
  604. mindspore/ops/_op_impl/aicpu/glu.py +33 -0
  605. mindspore/ops/_op_impl/aicpu/glu_grad.py +34 -0
  606. mindspore/ops/_op_impl/aicpu/greater.py +41 -0
  607. mindspore/ops/_op_impl/aicpu/greater_equal.py +41 -0
  608. mindspore/ops/_op_impl/aicpu/grid_sampler_2d.py +35 -0
  609. mindspore/ops/_op_impl/aicpu/grid_sampler_2d_grad.py +38 -0
  610. mindspore/ops/_op_impl/aicpu/grid_sampler_3d.py +34 -0
  611. mindspore/ops/_op_impl/aicpu/grid_sampler_3d_grad.py +38 -0
  612. mindspore/ops/_op_impl/aicpu/hamming_window.py +57 -0
  613. mindspore/ops/_op_impl/aicpu/hard_sigmoid.py +32 -0
  614. mindspore/ops/_op_impl/aicpu/hard_sigmoid_grad.py +33 -0
  615. mindspore/ops/_op_impl/aicpu/heaviside.py +40 -0
  616. mindspore/ops/_op_impl/aicpu/histogram.py +35 -0
  617. mindspore/ops/_op_impl/aicpu/hsv_to_rgb.py +32 -0
  618. mindspore/ops/_op_impl/aicpu/hypot.py +32 -0
  619. mindspore/ops/_op_impl/aicpu/identity.py +42 -0
  620. mindspore/ops/_op_impl/aicpu/identity_n.py +41 -0
  621. mindspore/ops/_op_impl/aicpu/igamma.py +30 -0
  622. mindspore/ops/_op_impl/aicpu/igammac.py +30 -0
  623. mindspore/ops/_op_impl/aicpu/igammagrada.py +30 -0
  624. mindspore/ops/_op_impl/aicpu/im2col.py +43 -0
  625. mindspore/ops/_op_impl/aicpu/imag.py +31 -0
  626. mindspore/ops/_op_impl/aicpu/index_fill.py +54 -0
  627. mindspore/ops/_op_impl/aicpu/index_put.py +50 -0
  628. mindspore/ops/_op_impl/aicpu/init_data_set_queue.py +27 -0
  629. mindspore/ops/_op_impl/aicpu/inplace_index_add.py +39 -0
  630. mindspore/ops/_op_impl/aicpu/instance_norm_v2.py +41 -0
  631. mindspore/ops/_op_impl/aicpu/instance_norm_v2_grad.py +44 -0
  632. mindspore/ops/_op_impl/aicpu/is_finite.py +40 -0
  633. mindspore/ops/_op_impl/aicpu/is_inf.py +31 -0
  634. mindspore/ops/_op_impl/aicpu/is_nan.py +31 -0
  635. mindspore/ops/_op_impl/aicpu/kldivloss.py +34 -0
  636. mindspore/ops/_op_impl/aicpu/kldivlossgrad.py +35 -0
  637. mindspore/ops/_op_impl/aicpu/layer_norm_grad_grad.py +47 -0
  638. mindspore/ops/_op_impl/aicpu/lcm.py +32 -0
  639. mindspore/ops/_op_impl/aicpu/left_shift.py +38 -0
  640. mindspore/ops/_op_impl/aicpu/less.py +41 -0
  641. mindspore/ops/_op_impl/aicpu/less_equal.py +41 -0
  642. mindspore/ops/_op_impl/aicpu/lgamma.py +33 -0
  643. mindspore/ops/_op_impl/aicpu/linear_sum_assignment.py +57 -0
  644. mindspore/ops/_op_impl/aicpu/linspace.py +33 -0
  645. mindspore/ops/_op_impl/aicpu/list_diff.py +50 -0
  646. mindspore/ops/_op_impl/aicpu/log.py +37 -0
  647. mindspore/ops/_op_impl/aicpu/log1p.py +34 -0
  648. mindspore/ops/_op_impl/aicpu/log_matrix_determinant.py +31 -0
  649. mindspore/ops/_op_impl/aicpu/log_normal_reverse.py +33 -0
  650. mindspore/ops/_op_impl/aicpu/log_uniform_candidate_sampler.py +37 -0
  651. mindspore/ops/_op_impl/aicpu/logical_xor.py +30 -0
  652. mindspore/ops/_op_impl/aicpu/logit.py +33 -0
  653. mindspore/ops/_op_impl/aicpu/logit_grad.py +34 -0
  654. mindspore/ops/_op_impl/aicpu/logspace.py +36 -0
  655. mindspore/ops/_op_impl/aicpu/lower_bound.py +47 -0
  656. mindspore/ops/_op_impl/aicpu/lstsq.py +34 -0
  657. mindspore/ops/_op_impl/aicpu/lu.py +39 -0
  658. mindspore/ops/_op_impl/aicpu/lu_solve.py +32 -0
  659. mindspore/ops/_op_impl/aicpu/lu_unpack.py +114 -0
  660. mindspore/ops/_op_impl/aicpu/lu_unpack_grad.py +49 -0
  661. mindspore/ops/_op_impl/aicpu/masked_fill.py +42 -0
  662. mindspore/ops/_op_impl/aicpu/masked_scatter.py +40 -0
  663. mindspore/ops/_op_impl/aicpu/masked_select.py +31 -0
  664. mindspore/ops/_op_impl/aicpu/masked_select_grad.py +35 -0
  665. mindspore/ops/_op_impl/aicpu/matmul.py +39 -0
  666. mindspore/ops/_op_impl/aicpu/matrix_band_part.py +59 -0
  667. mindspore/ops/_op_impl/aicpu/matrix_determinant.py +30 -0
  668. mindspore/ops/_op_impl/aicpu/matrix_diag_part_v3.py +54 -0
  669. mindspore/ops/_op_impl/aicpu/matrix_diag_v3.py +56 -0
  670. mindspore/ops/_op_impl/aicpu/matrix_exp.py +34 -0
  671. mindspore/ops/_op_impl/aicpu/matrix_inverse.py +31 -0
  672. mindspore/ops/_op_impl/aicpu/matrix_logarithm.py +31 -0
  673. mindspore/ops/_op_impl/aicpu/matrix_power.py +37 -0
  674. mindspore/ops/_op_impl/aicpu/matrix_set_diag_v3.py +54 -0
  675. mindspore/ops/_op_impl/aicpu/matrix_solve.py +35 -0
  676. mindspore/ops/_op_impl/aicpu/matrix_solve_ls.py +36 -0
  677. mindspore/ops/_op_impl/aicpu/matrix_triangular_solve.py +36 -0
  678. mindspore/ops/_op_impl/aicpu/max_pool3d_grad_with_argmax.py +60 -0
  679. mindspore/ops/_op_impl/aicpu/max_pool3d_with_argmax.py +59 -0
  680. mindspore/ops/_op_impl/aicpu/max_unpool2d.py +57 -0
  681. mindspore/ops/_op_impl/aicpu/max_unpool2d_grad.py +58 -0
  682. mindspore/ops/_op_impl/aicpu/max_unpool3d.py +57 -0
  683. mindspore/ops/_op_impl/aicpu/max_unpool3d_grad.py +58 -0
  684. mindspore/ops/_op_impl/aicpu/maximum_grad_grad.py +40 -0
  685. mindspore/ops/_op_impl/aicpu/maxpool_grad_v1.py +46 -0
  686. mindspore/ops/_op_impl/aicpu/maxpool_v1.py +42 -0
  687. mindspore/ops/_op_impl/aicpu/median.py +39 -0
  688. mindspore/ops/_op_impl/aicpu/median_grad.py +45 -0
  689. mindspore/ops/_op_impl/aicpu/meshgrid.py +41 -0
  690. mindspore/ops/_op_impl/aicpu/minimum_grad_grad.py +40 -0
  691. mindspore/ops/_op_impl/aicpu/mirror_pad.py +50 -0
  692. mindspore/ops/_op_impl/aicpu/mirror_pad_grad.py +48 -0
  693. mindspore/ops/_op_impl/aicpu/mul.py +43 -0
  694. mindspore/ops/_op_impl/aicpu/mul_no_nan.py +42 -0
  695. mindspore/ops/_op_impl/aicpu/multi_margin_loss.py +37 -0
  696. mindspore/ops/_op_impl/aicpu/multi_margin_loss_grad.py +41 -0
  697. mindspore/ops/_op_impl/aicpu/multilabel_margin_loss_grad.py +37 -0
  698. mindspore/ops/_op_impl/aicpu/multinomial.py +47 -0
  699. mindspore/ops/_op_impl/aicpu/multinomial_with_replacement.py +35 -0
  700. mindspore/ops/_op_impl/aicpu/mvlgamma.py +32 -0
  701. mindspore/ops/_op_impl/aicpu/mvlgamma_grad.py +33 -0
  702. mindspore/ops/_op_impl/aicpu/nan_to_num.py +34 -0
  703. mindspore/ops/_op_impl/aicpu/neg.py +36 -0
  704. mindspore/ops/_op_impl/aicpu/nextafter.py +32 -0
  705. mindspore/ops/_op_impl/aicpu/nllloss.py +38 -0
  706. mindspore/ops/_op_impl/aicpu/nllloss_grad.py +39 -0
  707. mindspore/ops/_op_impl/aicpu/no_repeat_ngram.py +34 -0
  708. mindspore/ops/_op_impl/aicpu/non_deterministic_ints.py +33 -0
  709. mindspore/ops/_op_impl/aicpu/non_max_suppression.py +36 -0
  710. mindspore/ops/_op_impl/aicpu/non_max_suppression_with_overlaps.py +35 -0
  711. mindspore/ops/_op_impl/aicpu/non_zero.py +43 -0
  712. mindspore/ops/_op_impl/aicpu/not_equal.py +39 -0
  713. mindspore/ops/_op_impl/aicpu/nth_element.py +39 -0
  714. mindspore/ops/_op_impl/aicpu/nuclear_norm.py +33 -0
  715. mindspore/ops/_op_impl/aicpu/one_hot.py +116 -0
  716. mindspore/ops/_op_impl/aicpu/ones_like.py +39 -0
  717. mindspore/ops/_op_impl/aicpu/orgqr.py +34 -0
  718. mindspore/ops/_op_impl/aicpu/pad_and_shift.py +33 -0
  719. mindspore/ops/_op_impl/aicpu/pad_v3.py +61 -0
  720. mindspore/ops/_op_impl/aicpu/pad_v3_grad.py +59 -0
  721. mindspore/ops/_op_impl/aicpu/padding.py +41 -0
  722. mindspore/ops/_op_impl/aicpu/parameterized_truncated_normal.py +54 -0
  723. mindspore/ops/_op_impl/aicpu/pdist_grad.py +33 -0
  724. mindspore/ops/_op_impl/aicpu/poisson.py +37 -0
  725. mindspore/ops/_op_impl/aicpu/polar.py +32 -0
  726. mindspore/ops/_op_impl/aicpu/polygamma.py +34 -0
  727. mindspore/ops/_op_impl/aicpu/pow.py +39 -0
  728. mindspore/ops/_op_impl/aicpu/print_tensor.py +39 -0
  729. mindspore/ops/_op_impl/aicpu/priority_replay_buffer.py +113 -0
  730. mindspore/ops/_op_impl/aicpu/qr.py +36 -0
  731. mindspore/ops/_op_impl/aicpu/quant_dtype_cast.py +40 -0
  732. mindspore/ops/_op_impl/aicpu/quantile.py +35 -0
  733. mindspore/ops/_op_impl/aicpu/ragged_range.py +49 -0
  734. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_sparse.py +73 -0
  735. mindspore/ops/_op_impl/aicpu/ragged_tensor_to_tensor.py +74 -0
  736. mindspore/ops/_op_impl/aicpu/random_categorical.py +68 -0
  737. mindspore/ops/_op_impl/aicpu/random_choice_with_mask.py +36 -0
  738. mindspore/ops/_op_impl/aicpu/random_gamma.py +38 -0
  739. mindspore/ops/_op_impl/aicpu/random_poisson.py +134 -0
  740. mindspore/ops/_op_impl/aicpu/random_shuffle.py +47 -0
  741. mindspore/ops/_op_impl/aicpu/randperm.py +38 -0
  742. mindspore/ops/_op_impl/aicpu/randperm_v2.py +41 -0
  743. mindspore/ops/_op_impl/aicpu/range.py +36 -0
  744. mindspore/ops/_op_impl/aicpu/range_v2.py +35 -0
  745. mindspore/ops/_op_impl/aicpu/real.py +31 -0
  746. mindspore/ops/_op_impl/aicpu/real_div.py +40 -0
  747. mindspore/ops/_op_impl/aicpu/reciprocal.py +34 -0
  748. mindspore/ops/_op_impl/aicpu/reciprocal_grad.py +35 -0
  749. mindspore/ops/_op_impl/aicpu/reduce_mean.py +57 -0
  750. mindspore/ops/_op_impl/aicpu/reduce_prod.py +57 -0
  751. mindspore/ops/_op_impl/aicpu/reduce_sum.py +57 -0
  752. mindspore/ops/_op_impl/aicpu/relu_grad_v3.py +41 -0
  753. mindspore/ops/_op_impl/aicpu/relu_v3.py +38 -0
  754. mindspore/ops/_op_impl/aicpu/reservoir_replay_buffer.py +96 -0
  755. mindspore/ops/_op_impl/aicpu/reshape.py +42 -0
  756. mindspore/ops/_op_impl/aicpu/resize_area.py +40 -0
  757. mindspore/ops/_op_impl/aicpu/resize_bicubic.py +20 -0
  758. mindspore/ops/_op_impl/aicpu/resize_bicubic_grad.py +19 -0
  759. mindspore/ops/_op_impl/aicpu/resize_bilinear.py +32 -0
  760. mindspore/ops/_op_impl/aicpu/resize_bilinear_grad.py +32 -0
  761. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2.py +36 -0
  762. mindspore/ops/_op_impl/aicpu/resize_nearest_neighbor_v2_grad.py +35 -0
  763. mindspore/ops/_op_impl/aicpu/resize_v2.py +68 -0
  764. mindspore/ops/_op_impl/aicpu/resize_v2_grad.py +68 -0
  765. mindspore/ops/_op_impl/aicpu/reverse_sequence.py +55 -0
  766. mindspore/ops/_op_impl/aicpu/reversev2.py +54 -0
  767. mindspore/ops/_op_impl/aicpu/rgb_to_hsv.py +32 -0
  768. mindspore/ops/_op_impl/aicpu/right_shift.py +38 -0
  769. mindspore/ops/_op_impl/aicpu/rnnt_loss.py +35 -0
  770. mindspore/ops/_op_impl/aicpu/round.py +34 -0
  771. mindspore/ops/_op_impl/aicpu/rsqrt.py +33 -0
  772. mindspore/ops/_op_impl/aicpu/rsqrt_grad.py +36 -0
  773. mindspore/ops/_op_impl/aicpu/sample_distorted_bounding_box_v2.py +49 -0
  774. mindspore/ops/_op_impl/aicpu/scale_and_translate.py +52 -0
  775. mindspore/ops/_op_impl/aicpu/scale_and_translate_grad.py +36 -0
  776. mindspore/ops/_op_impl/aicpu/scatter.py +79 -0
  777. mindspore/ops/_op_impl/aicpu/scatter_add_with_axis.py +53 -0
  778. mindspore/ops/_op_impl/aicpu/scatter_elements.py +39 -0
  779. mindspore/ops/_op_impl/aicpu/scatter_nd.py +59 -0
  780. mindspore/ops/_op_impl/aicpu/scatter_nd_max.py +54 -0
  781. mindspore/ops/_op_impl/aicpu/scatter_nd_min.py +54 -0
  782. mindspore/ops/_op_impl/aicpu/scatter_nd_update.py +59 -0
  783. mindspore/ops/_op_impl/aicpu/search_sorted.py +44 -0
  784. mindspore/ops/_op_impl/aicpu/segment_max.py +52 -0
  785. mindspore/ops/_op_impl/aicpu/segment_mean.py +56 -0
  786. mindspore/ops/_op_impl/aicpu/segment_min.py +52 -0
  787. mindspore/ops/_op_impl/aicpu/segment_prod.py +56 -0
  788. mindspore/ops/_op_impl/aicpu/segment_sum.py +56 -0
  789. mindspore/ops/_op_impl/aicpu/select.py +45 -0
  790. mindspore/ops/_op_impl/aicpu/self_adjoint_eig.py +34 -0
  791. mindspore/ops/_op_impl/aicpu/sequence_add.py +34 -0
  792. mindspore/ops/_op_impl/aicpu/sequence_add_offset.py +34 -0
  793. mindspore/ops/_op_impl/aicpu/sequence_addn.py +38 -0
  794. mindspore/ops/_op_impl/aicpu/sequence_concat.py +40 -0
  795. mindspore/ops/_op_impl/aicpu/sequence_stack.py +40 -0
  796. mindspore/ops/_op_impl/aicpu/set_size.py +38 -0
  797. mindspore/ops/_op_impl/aicpu/sign.py +36 -0
  798. mindspore/ops/_op_impl/aicpu/sin.py +34 -0
  799. mindspore/ops/_op_impl/aicpu/sinc.py +43 -0
  800. mindspore/ops/_op_impl/aicpu/sinh.py +34 -0
  801. mindspore/ops/_op_impl/aicpu/slice.py +59 -0
  802. mindspore/ops/_op_impl/aicpu/slice_grad.py +76 -0
  803. mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py +35 -0
  804. mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py +37 -0
  805. mindspore/ops/_op_impl/aicpu/sort.py +39 -0
  806. mindspore/ops/_op_impl/aicpu/space_to_depth.py +44 -0
  807. mindspore/ops/_op_impl/aicpu/sparse_addmm.py +87 -0
  808. mindspore/ops/_op_impl/aicpu/sparse_apply_adagrad_da.py +80 -0
  809. mindspore/ops/_op_impl/aicpu/sparse_apply_centered_rms_prop.py +105 -0
  810. mindspore/ops/_op_impl/aicpu/sparse_apply_momentum.py +80 -0
  811. mindspore/ops/_op_impl/aicpu/sparse_apply_proximal_gradient_descent.py +79 -0
  812. mindspore/ops/_op_impl/aicpu/sparse_concat.py +59 -0
  813. mindspore/ops/_op_impl/aicpu/sparse_cross.py +42 -0
  814. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_add.py +58 -0
  815. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_div.py +58 -0
  816. mindspore/ops/_op_impl/aicpu/sparse_dense_cwise_mul.py +58 -0
  817. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows.py +63 -0
  818. mindspore/ops/_op_impl/aicpu/sparse_fill_empty_rows_grad.py +45 -0
  819. mindspore/ops/_op_impl/aicpu/sparse_matrix_mat_mul.py +56 -0
  820. mindspore/ops/_op_impl/aicpu/sparse_matrix_nnz.py +81 -0
  821. mindspore/ops/_op_impl/aicpu/sparse_matrix_transpose.py +116 -0
  822. mindspore/ops/_op_impl/aicpu/sparse_reorder.py +56 -0
  823. mindspore/ops/_op_impl/aicpu/sparse_reshape.py +34 -0
  824. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_grad.py +36 -0
  825. mindspore/ops/_op_impl/aicpu/sparse_segment_mean_with_num_segments.py +44 -0
  826. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n.py +43 -0
  827. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_grad.py +38 -0
  828. mindspore/ops/_op_impl/aicpu/sparse_segment_sqrt_n_with_num_segments.py +44 -0
  829. mindspore/ops/_op_impl/aicpu/sparse_segment_sum.py +49 -0
  830. mindspore/ops/_op_impl/aicpu/sparse_segment_sum_with_num_segments.py +68 -0
  831. mindspore/ops/_op_impl/aicpu/sparse_slice.py +63 -0
  832. mindspore/ops/_op_impl/aicpu/sparse_slice_grad.py +61 -0
  833. mindspore/ops/_op_impl/aicpu/sparse_softmax.py +33 -0
  834. mindspore/ops/_op_impl/aicpu/sparse_softmax_cross_entropy_with_logits_v2.py +35 -0
  835. mindspore/ops/_op_impl/aicpu/sparse_sparse_maximum.py +53 -0
  836. mindspore/ops/_op_impl/aicpu/sparse_sparse_minimum.py +53 -0
  837. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_add.py +84 -0
  838. mindspore/ops/_op_impl/aicpu/sparse_tensor_dense_mat_mul.py +190 -0
  839. mindspore/ops/_op_impl/aicpu/sparse_tensor_to_csr_sparse_matrix.py +51 -0
  840. mindspore/ops/_op_impl/aicpu/sparse_to_dense_v2.py +73 -0
  841. mindspore/ops/_op_impl/aicpu/split.py +45 -0
  842. mindspore/ops/_op_impl/aicpu/sqrt.py +34 -0
  843. mindspore/ops/_op_impl/aicpu/sqrt_grad.py +35 -0
  844. mindspore/ops/_op_impl/aicpu/square.py +35 -0
  845. mindspore/ops/_op_impl/aicpu/squared_difference.py +37 -0
  846. mindspore/ops/_op_impl/aicpu/squeeze.py +42 -0
  847. mindspore/ops/_op_impl/aicpu/sspaddmm.py +97 -0
  848. mindspore/ops/_op_impl/aicpu/stack.py +45 -0
  849. mindspore/ops/_op_impl/aicpu/stack_push_pop.py +87 -0
  850. mindspore/ops/_op_impl/aicpu/standard_laplace.py +34 -0
  851. mindspore/ops/_op_impl/aicpu/standard_normal.py +34 -0
  852. mindspore/ops/_op_impl/aicpu/stateless_dropout_genmask.py +37 -0
  853. mindspore/ops/_op_impl/aicpu/stft.py +70 -0
  854. mindspore/ops/_op_impl/aicpu/strided_slice.py +43 -0
  855. mindspore/ops/_op_impl/aicpu/strided_slice_grad.py +50 -0
  856. mindspore/ops/_op_impl/aicpu/sub.py +41 -0
  857. mindspore/ops/_op_impl/aicpu/sub_and_filter.py +36 -0
  858. mindspore/ops/_op_impl/aicpu/tan.py +34 -0
  859. mindspore/ops/_op_impl/aicpu/tanh.py +34 -0
  860. mindspore/ops/_op_impl/aicpu/tanh_grad.py +35 -0
  861. mindspore/ops/_op_impl/aicpu/tensor_scatter_update.py +59 -0
  862. mindspore/ops/_op_impl/aicpu/tile.py +56 -0
  863. mindspore/ops/_op_impl/aicpu/topk.py +34 -0
  864. mindspore/ops/_op_impl/aicpu/trace.py +40 -0
  865. mindspore/ops/_op_impl/aicpu/tracegrad.py +41 -0
  866. mindspore/ops/_op_impl/aicpu/trans_data.py +35 -0
  867. mindspore/ops/_op_impl/aicpu/transpose.py +58 -0
  868. mindspore/ops/_op_impl/aicpu/tridiagonal_matmul.py +42 -0
  869. mindspore/ops/_op_impl/aicpu/tridiagonal_solve.py +35 -0
  870. mindspore/ops/_op_impl/aicpu/tril.py +42 -0
  871. mindspore/ops/_op_impl/aicpu/tril_indices.py +34 -0
  872. mindspore/ops/_op_impl/aicpu/triplet_margin_loss.py +62 -0
  873. mindspore/ops/_op_impl/aicpu/triu.py +43 -0
  874. mindspore/ops/_op_impl/aicpu/triu_indices.py +34 -0
  875. mindspore/ops/_op_impl/aicpu/truncated_normal.py +39 -0
  876. mindspore/ops/_op_impl/aicpu/uniform.py +36 -0
  877. mindspore/ops/_op_impl/aicpu/uniform_candidate_sampler.py +41 -0
  878. mindspore/ops/_op_impl/aicpu/uniform_int.py +36 -0
  879. mindspore/ops/_op_impl/aicpu/uniform_real.py +33 -0
  880. mindspore/ops/_op_impl/aicpu/unique.py +31 -0
  881. mindspore/ops/_op_impl/aicpu/unique_consecutive.py +47 -0
  882. mindspore/ops/_op_impl/aicpu/unique_with_pad.py +32 -0
  883. mindspore/ops/_op_impl/aicpu/unravel_index.py +32 -0
  884. mindspore/ops/_op_impl/aicpu/unsorted_segment_prod.py +53 -0
  885. mindspore/ops/_op_impl/aicpu/unsorted_segment_sum.py +57 -0
  886. mindspore/ops/_op_impl/aicpu/unstack.py +45 -0
  887. mindspore/ops/_op_impl/aicpu/update_cache.py +44 -0
  888. mindspore/ops/_op_impl/aicpu/upper_bound.py +47 -0
  889. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d.py +42 -0
  890. mindspore/ops/_op_impl/aicpu/upsample_nearest_3d_grad.py +49 -0
  891. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d.py +40 -0
  892. mindspore/ops/_op_impl/aicpu/upsample_trilinear_3d_grad.py +50 -0
  893. mindspore/ops/_op_impl/aicpu/xdivy.py +35 -0
  894. mindspore/ops/_op_impl/aicpu/xlogy.py +33 -0
  895. mindspore/ops/_op_impl/aicpu/zeros_like.py +42 -0
  896. mindspore/ops/_op_impl/aicpu/zeta.py +31 -0
  897. mindspore/ops/_op_impl/akg/__init__.py +19 -0
  898. mindspore/ops/_op_impl/akg/ascend/__init__.py +48 -0
  899. mindspore/ops/_op_impl/akg/ascend/abs.py +35 -0
  900. mindspore/ops/_op_impl/akg/ascend/add.py +42 -0
  901. mindspore/ops/_op_impl/akg/ascend/add_n.py +37 -0
  902. mindspore/ops/_op_impl/akg/ascend/batchmatmul.py +33 -0
  903. mindspore/ops/_op_impl/akg/ascend/cast.py +46 -0
  904. mindspore/ops/_op_impl/akg/ascend/equal.py +35 -0
  905. mindspore/ops/_op_impl/akg/ascend/exp.py +35 -0
  906. mindspore/ops/_op_impl/akg/ascend/expand_dims.py +33 -0
  907. mindspore/ops/_op_impl/akg/ascend/greater.py +34 -0
  908. mindspore/ops/_op_impl/akg/ascend/greater_equal.py +35 -0
  909. mindspore/ops/_op_impl/akg/ascend/less.py +31 -0
  910. mindspore/ops/_op_impl/akg/ascend/less_equal.py +35 -0
  911. mindspore/ops/_op_impl/akg/ascend/load_im2col.py +33 -0
  912. mindspore/ops/_op_impl/akg/ascend/log.py +34 -0
  913. mindspore/ops/_op_impl/akg/ascend/maximum.py +36 -0
  914. mindspore/ops/_op_impl/akg/ascend/minimum.py +39 -0
  915. mindspore/ops/_op_impl/akg/ascend/mul.py +41 -0
  916. mindspore/ops/_op_impl/akg/ascend/neg.py +37 -0
  917. mindspore/ops/_op_impl/akg/ascend/pow.py +35 -0
  918. mindspore/ops/_op_impl/akg/ascend/prod_force_se_a.py +33 -0
  919. mindspore/ops/_op_impl/akg/ascend/real_div.py +36 -0
  920. mindspore/ops/_op_impl/akg/ascend/reciprocal.py +32 -0
  921. mindspore/ops/_op_impl/akg/ascend/reduce_max.py +32 -0
  922. mindspore/ops/_op_impl/akg/ascend/reduce_min.py +32 -0
  923. mindspore/ops/_op_impl/akg/ascend/reduce_sum.py +37 -0
  924. mindspore/ops/_op_impl/akg/ascend/rsqrt.py +35 -0
  925. mindspore/ops/_op_impl/akg/ascend/select.py +37 -0
  926. mindspore/ops/_op_impl/akg/ascend/sqrt.py +35 -0
  927. mindspore/ops/_op_impl/akg/ascend/square.py +35 -0
  928. mindspore/ops/_op_impl/akg/ascend/sub.py +42 -0
  929. mindspore/ops/_op_impl/akg/cpu/__init__.py +23 -0
  930. mindspore/ops/_op_impl/akg/cpu/coo2csr.py +29 -0
  931. mindspore/ops/_op_impl/akg/cpu/csr2coo.py +29 -0
  932. mindspore/ops/_op_impl/akg/cpu/csr_gather.py +33 -0
  933. mindspore/ops/_op_impl/akg/cpu/csr_mm.py +34 -0
  934. mindspore/ops/_op_impl/akg/cpu/csr_mul.py +33 -0
  935. mindspore/ops/_op_impl/akg/cpu/csr_mv.py +33 -0
  936. mindspore/ops/_op_impl/akg/cpu/csr_reduce_sum.py +31 -0
  937. mindspore/ops/_op_impl/akg/gpu/__init__.py +24 -0
  938. mindspore/ops/_op_impl/akg/gpu/coo2csr.py +29 -0
  939. mindspore/ops/_op_impl/akg/gpu/csr2coo.py +29 -0
  940. mindspore/ops/_op_impl/akg/gpu/csr_div.py +36 -0
  941. mindspore/ops/_op_impl/akg/gpu/csr_gather.py +33 -0
  942. mindspore/ops/_op_impl/akg/gpu/csr_mm.py +37 -0
  943. mindspore/ops/_op_impl/akg/gpu/csr_mul.py +36 -0
  944. mindspore/ops/_op_impl/akg/gpu/csr_mv.py +36 -0
  945. mindspore/ops/_op_impl/akg/gpu/csr_reduce_sum.py +33 -0
  946. mindspore/ops/_op_impl/cpu/__init__.py +78 -0
  947. mindspore/ops/_op_impl/cpu/adam.py +49 -0
  948. mindspore/ops/_op_impl/cpu/adam_weight_decay.py +47 -0
  949. mindspore/ops/_op_impl/cpu/arg_max.py +30 -0
  950. mindspore/ops/_op_impl/cpu/arg_max_with_value.py +31 -0
  951. mindspore/ops/_op_impl/cpu/arg_min_with_value.py +31 -0
  952. mindspore/ops/_op_impl/cpu/buffer_append.py +28 -0
  953. mindspore/ops/_op_impl/cpu/buffer_get.py +28 -0
  954. mindspore/ops/_op_impl/cpu/buffer_sample.py +28 -0
  955. mindspore/ops/_op_impl/cpu/cast.py +171 -0
  956. mindspore/ops/_op_impl/cpu/concat_offset.py +38 -0
  957. mindspore/ops/_op_impl/cpu/conv2d.py +30 -0
  958. mindspore/ops/_op_impl/cpu/conv3d.py +30 -0
  959. mindspore/ops/_op_impl/cpu/div.py +32 -0
  960. mindspore/ops/_op_impl/cpu/dropout.py +31 -0
  961. mindspore/ops/_op_impl/cpu/dropout_grad.py +30 -0
  962. mindspore/ops/_op_impl/cpu/dynamic_shape.py +42 -0
  963. mindspore/ops/_op_impl/cpu/dynamic_stitch.py +41 -0
  964. mindspore/ops/_op_impl/cpu/equal_count.py +30 -0
  965. mindspore/ops/_op_impl/cpu/gather_d.py +49 -0
  966. mindspore/ops/_op_impl/cpu/gather_d_grad.py +38 -0
  967. mindspore/ops/_op_impl/cpu/gather_d_grad_v2.py +40 -0
  968. mindspore/ops/_op_impl/cpu/gather_v2.py +40 -0
  969. mindspore/ops/_op_impl/cpu/hsigmoid.py +33 -0
  970. mindspore/ops/_op_impl/cpu/hsigmoid_grad.py +34 -0
  971. mindspore/ops/_op_impl/cpu/hswish.py +32 -0
  972. mindspore/ops/_op_impl/cpu/hswish_grad.py +33 -0
  973. mindspore/ops/_op_impl/cpu/identity_n.py +40 -0
  974. mindspore/ops/_op_impl/cpu/is_finite.py +39 -0
  975. mindspore/ops/_op_impl/cpu/l2loss.py +30 -0
  976. mindspore/ops/_op_impl/cpu/layer_norm.py +36 -0
  977. mindspore/ops/_op_impl/cpu/layer_norm_grad.py +38 -0
  978. mindspore/ops/_op_impl/cpu/maximum.py +35 -0
  979. mindspore/ops/_op_impl/cpu/maximum_grad.py +47 -0
  980. mindspore/ops/_op_impl/cpu/minimum.py +40 -0
  981. mindspore/ops/_op_impl/cpu/minimum_grad.py +51 -0
  982. mindspore/ops/_op_impl/cpu/mirror_pad.py +36 -0
  983. mindspore/ops/_op_impl/cpu/mirror_pad_grad.py +36 -0
  984. mindspore/ops/_op_impl/cpu/mul.py +32 -0
  985. mindspore/ops/_op_impl/cpu/one_hot.py +31 -0
  986. mindspore/ops/_op_impl/cpu/pad.py +32 -0
  987. mindspore/ops/_op_impl/cpu/pow.py +32 -0
  988. mindspore/ops/_op_impl/cpu/priority_replay_buffer.py +42 -0
  989. mindspore/ops/_op_impl/cpu/pyexecute.py +29 -0
  990. mindspore/ops/_op_impl/cpu/pyfunc.py +29 -0
  991. mindspore/ops/_op_impl/cpu/range.py +34 -0
  992. mindspore/ops/_op_impl/cpu/real_div.py +33 -0
  993. mindspore/ops/_op_impl/cpu/reduce_all.py +29 -0
  994. mindspore/ops/_op_impl/cpu/reduce_any.py +29 -0
  995. mindspore/ops/_op_impl/cpu/reduce_max.py +32 -0
  996. mindspore/ops/_op_impl/cpu/reduce_mean.py +40 -0
  997. mindspore/ops/_op_impl/cpu/reduce_min.py +32 -0
  998. mindspore/ops/_op_impl/cpu/reduce_prod.py +40 -0
  999. mindspore/ops/_op_impl/cpu/reduce_std.py +31 -0
  1000. mindspore/ops/_op_impl/cpu/reduce_sum.py +41 -0
  1001. mindspore/ops/_op_impl/cpu/space_to_batch_nd.py +38 -0
  1002. mindspore/ops/_op_impl/cpu/sparse_slice.py +62 -0
  1003. mindspore/ops/_op_impl/cpu/sparse_slice_grad.py +60 -0
  1004. mindspore/ops/_op_impl/cpu/split.py +34 -0
  1005. mindspore/ops/_op_impl/cpu/sspaddmm.py +95 -0
  1006. mindspore/ops/_op_impl/cpu/stack.py +38 -0
  1007. mindspore/ops/_op_impl/cpu/sub.py +32 -0
  1008. mindspore/ops/_op_impl/cpu/tensor_copy_slices.py +41 -0
  1009. mindspore/ops/_op_impl/cpu/tile.py +37 -0
  1010. mindspore/ops/_op_impl/cpu/top_k.py +31 -0
  1011. mindspore/ops/_op_impl/cpu/transpose.py +39 -0
  1012. mindspore/ops/_primitive_cache.py +90 -0
  1013. mindspore/ops/_register_for_op.py +73 -0
  1014. mindspore/ops/_utils/__init__.py +20 -0
  1015. mindspore/ops/_utils/utils.py +147 -0
  1016. mindspore/ops/_vmap/__init__.py +25 -0
  1017. mindspore/ops/_vmap/vmap_array_ops.py +2149 -0
  1018. mindspore/ops/_vmap/vmap_base.py +533 -0
  1019. mindspore/ops/_vmap/vmap_convolution_ops.py +441 -0
  1020. mindspore/ops/_vmap/vmap_debug_ops.py +50 -0
  1021. mindspore/ops/_vmap/vmap_grad_math_ops.py +274 -0
  1022. mindspore/ops/_vmap/vmap_grad_nn_ops.py +806 -0
  1023. mindspore/ops/_vmap/vmap_image_ops.py +194 -0
  1024. mindspore/ops/_vmap/vmap_math_ops.py +993 -0
  1025. mindspore/ops/_vmap/vmap_nn_ops.py +2250 -0
  1026. mindspore/ops/_vmap/vmap_other_ops.py +105 -0
  1027. mindspore/ops/_vmap/vmap_random_ops.py +122 -0
  1028. mindspore/ops/_vmap/vmap_sparse_ops.py +89 -0
  1029. mindspore/ops/auto_generate/__init__.py +31 -0
  1030. mindspore/ops/auto_generate/cpp_create_prim_instance_helper.py +309 -0
  1031. mindspore/ops/auto_generate/gen_arg_dtype_cast.py +252 -0
  1032. mindspore/ops/auto_generate/gen_arg_handler.py +197 -0
  1033. mindspore/ops/auto_generate/gen_extend_func.py +1701 -0
  1034. mindspore/ops/auto_generate/gen_ops_def.py +8482 -0
  1035. mindspore/ops/auto_generate/gen_ops_prim.py +16704 -0
  1036. mindspore/ops/auto_generate/pyboost_inner_prim.py +549 -0
  1037. mindspore/ops/composite/__init__.py +71 -0
  1038. mindspore/ops/composite/base.py +1318 -0
  1039. mindspore/ops/composite/env_ops.py +41 -0
  1040. mindspore/ops/composite/math_ops.py +125 -0
  1041. mindspore/ops/composite/multitype_ops/__init__.py +77 -0
  1042. mindspore/ops/composite/multitype_ops/_compile_utils.py +1459 -0
  1043. mindspore/ops/composite/multitype_ops/_constexpr_utils.py +897 -0
  1044. mindspore/ops/composite/multitype_ops/add_impl.py +606 -0
  1045. mindspore/ops/composite/multitype_ops/bitwise_and_impl.py +56 -0
  1046. mindspore/ops/composite/multitype_ops/bitwise_or_impl.py +56 -0
  1047. mindspore/ops/composite/multitype_ops/bitwise_xor_impl.py +56 -0
  1048. mindspore/ops/composite/multitype_ops/div_impl.py +189 -0
  1049. mindspore/ops/composite/multitype_ops/equal_impl.py +335 -0
  1050. mindspore/ops/composite/multitype_ops/floordiv_impl.py +88 -0
  1051. mindspore/ops/composite/multitype_ops/getitem_impl.py +400 -0
  1052. mindspore/ops/composite/multitype_ops/greater_equal_impl.py +109 -0
  1053. mindspore/ops/composite/multitype_ops/greater_impl.py +110 -0
  1054. mindspore/ops/composite/multitype_ops/in_impl.py +196 -0
  1055. mindspore/ops/composite/multitype_ops/left_shift_impl.py +37 -0
  1056. mindspore/ops/composite/multitype_ops/less_equal_impl.py +111 -0
  1057. mindspore/ops/composite/multitype_ops/less_impl.py +112 -0
  1058. mindspore/ops/composite/multitype_ops/logic_not_impl.py +113 -0
  1059. mindspore/ops/composite/multitype_ops/logical_and_impl.py +60 -0
  1060. mindspore/ops/composite/multitype_ops/logical_or_impl.py +61 -0
  1061. mindspore/ops/composite/multitype_ops/mod_impl.py +86 -0
  1062. mindspore/ops/composite/multitype_ops/mul_impl.py +294 -0
  1063. mindspore/ops/composite/multitype_ops/negative_impl.py +79 -0
  1064. mindspore/ops/composite/multitype_ops/not_equal_impl.py +290 -0
  1065. mindspore/ops/composite/multitype_ops/not_in_impl.py +196 -0
  1066. mindspore/ops/composite/multitype_ops/ones_like_impl.py +96 -0
  1067. mindspore/ops/composite/multitype_ops/pow_impl.py +87 -0
  1068. mindspore/ops/composite/multitype_ops/right_shift_impl.py +37 -0
  1069. mindspore/ops/composite/multitype_ops/setitem_impl.py +884 -0
  1070. mindspore/ops/composite/multitype_ops/sub_impl.py +116 -0
  1071. mindspore/ops/composite/multitype_ops/uadd_impl.py +29 -0
  1072. mindspore/ops/composite/multitype_ops/zeros_like_impl.py +228 -0
  1073. mindspore/ops/deprecated.py +315 -0
  1074. mindspore/ops/function/__init__.py +782 -0
  1075. mindspore/ops/function/array_func.py +7226 -0
  1076. mindspore/ops/function/clip_func.py +384 -0
  1077. mindspore/ops/function/debug_func.py +181 -0
  1078. mindspore/ops/function/fft_func.py +44 -0
  1079. mindspore/ops/function/grad/__init__.py +34 -0
  1080. mindspore/ops/function/grad/grad_func.py +1425 -0
  1081. mindspore/ops/function/image_func.py +292 -0
  1082. mindspore/ops/function/linalg_func.py +416 -0
  1083. mindspore/ops/function/math_func.py +12228 -0
  1084. mindspore/ops/function/nn_func.py +8609 -0
  1085. mindspore/ops/function/other_func.py +115 -0
  1086. mindspore/ops/function/parameter_func.py +134 -0
  1087. mindspore/ops/function/random_func.py +1715 -0
  1088. mindspore/ops/function/reshard_func.py +104 -0
  1089. mindspore/ops/function/sparse_func.py +884 -0
  1090. mindspore/ops/function/sparse_unary_func.py +2422 -0
  1091. mindspore/ops/function/spectral_func.py +150 -0
  1092. mindspore/ops/function/vmap_func.py +117 -0
  1093. mindspore/ops/functional.py +464 -0
  1094. mindspore/ops/op_info_register.py +1572 -0
  1095. mindspore/ops/operations/__init__.py +722 -0
  1096. mindspore/ops/operations/_csr_ops.py +403 -0
  1097. mindspore/ops/operations/_custom_grad.py +181 -0
  1098. mindspore/ops/operations/_embedding_cache_ops.py +307 -0
  1099. mindspore/ops/operations/_grad_ops.py +2978 -0
  1100. mindspore/ops/operations/_infer_ops.py +19 -0
  1101. mindspore/ops/operations/_inner_ops.py +2544 -0
  1102. mindspore/ops/operations/_map_tensor_ops.py +112 -0
  1103. mindspore/ops/operations/_ms_kernel.py +601 -0
  1104. mindspore/ops/operations/_ocr_ops.py +379 -0
  1105. mindspore/ops/operations/_opaque_predicate_registry.py +41 -0
  1106. mindspore/ops/operations/_pyfunc_registry.py +58 -0
  1107. mindspore/ops/operations/_quant_ops.py +1844 -0
  1108. mindspore/ops/operations/_rl_inner_ops.py +1231 -0
  1109. mindspore/ops/operations/_scalar_ops.py +106 -0
  1110. mindspore/ops/operations/_sequence_ops.py +1155 -0
  1111. mindspore/ops/operations/_sparse_grad_ops.py +56 -0
  1112. mindspore/ops/operations/_tensor_array.py +359 -0
  1113. mindspore/ops/operations/_thor_ops.py +807 -0
  1114. mindspore/ops/operations/array_ops.py +6124 -0
  1115. mindspore/ops/operations/comm_ops.py +1985 -0
  1116. mindspore/ops/operations/control_ops.py +127 -0
  1117. mindspore/ops/operations/custom_ops.py +1129 -0
  1118. mindspore/ops/operations/debug_ops.py +678 -0
  1119. mindspore/ops/operations/image_ops.py +1041 -0
  1120. mindspore/ops/operations/inner_ops.py +697 -0
  1121. mindspore/ops/operations/linalg_ops.py +95 -0
  1122. mindspore/ops/operations/manually_defined/__init__.py +24 -0
  1123. mindspore/ops/operations/manually_defined/_inner.py +73 -0
  1124. mindspore/ops/operations/manually_defined/ops_def.py +2271 -0
  1125. mindspore/ops/operations/math_ops.py +5095 -0
  1126. mindspore/ops/operations/nn_ops.py +9575 -0
  1127. mindspore/ops/operations/other_ops.py +874 -0
  1128. mindspore/ops/operations/random_ops.py +1288 -0
  1129. mindspore/ops/operations/reshard_ops.py +53 -0
  1130. mindspore/ops/operations/rl_ops.py +288 -0
  1131. mindspore/ops/operations/sparse_ops.py +2753 -0
  1132. mindspore/ops/operations/spectral_ops.py +111 -0
  1133. mindspore/ops/primitive.py +1046 -0
  1134. mindspore/ops/signature.py +54 -0
  1135. mindspore/ops/vm_impl_registry.py +91 -0
  1136. mindspore/ops_generate/__init__.py +27 -0
  1137. mindspore/ops_generate/arg_dtype_cast.py +252 -0
  1138. mindspore/ops_generate/arg_handler.py +197 -0
  1139. mindspore/ops_generate/gen_aclnn_implement.py +263 -0
  1140. mindspore/ops_generate/gen_constants.py +36 -0
  1141. mindspore/ops_generate/gen_ops.py +1099 -0
  1142. mindspore/ops_generate/gen_ops_inner_prim.py +131 -0
  1143. mindspore/ops_generate/gen_pyboost_func.py +1052 -0
  1144. mindspore/ops_generate/gen_utils.py +209 -0
  1145. mindspore/ops_generate/op_proto.py +145 -0
  1146. mindspore/ops_generate/pyboost_utils.py +367 -0
  1147. mindspore/ops_generate/template.py +261 -0
  1148. mindspore/parallel/__init__.py +30 -0
  1149. mindspore/parallel/_auto_parallel_context.py +1486 -0
  1150. mindspore/parallel/_cell_wrapper.py +174 -0
  1151. mindspore/parallel/_cost_model_context.py +700 -0
  1152. mindspore/parallel/_dp_allreduce_fusion.py +159 -0
  1153. mindspore/parallel/_offload_context.py +275 -0
  1154. mindspore/parallel/_parallel_serialization.py +561 -0
  1155. mindspore/parallel/_ps_context.py +242 -0
  1156. mindspore/parallel/_recovery_context.py +110 -0
  1157. mindspore/parallel/_tensor.py +730 -0
  1158. mindspore/parallel/_transformer/__init__.py +35 -0
  1159. mindspore/parallel/_transformer/layers.py +765 -0
  1160. mindspore/parallel/_transformer/loss.py +251 -0
  1161. mindspore/parallel/_transformer/moe.py +693 -0
  1162. mindspore/parallel/_transformer/op_parallel_config.py +222 -0
  1163. mindspore/parallel/_transformer/transformer.py +3119 -0
  1164. mindspore/parallel/_utils.py +612 -0
  1165. mindspore/parallel/algo_parameter_config.py +400 -0
  1166. mindspore/parallel/checkpoint_transform.py +650 -0
  1167. mindspore/parallel/cluster/__init__.py +15 -0
  1168. mindspore/parallel/cluster/process_entity/__init__.py +18 -0
  1169. mindspore/parallel/cluster/process_entity/_api.py +352 -0
  1170. mindspore/parallel/cluster/process_entity/_utils.py +101 -0
  1171. mindspore/parallel/cluster/run.py +136 -0
  1172. mindspore/parallel/mpi/__init__.py +14 -0
  1173. mindspore/parallel/mpi/_mpi_config.py +116 -0
  1174. mindspore/parallel/parameter_broadcast.py +151 -0
  1175. mindspore/parallel/shard.py +481 -0
  1176. mindspore/parallel/transform_safetensors.py +993 -0
  1177. mindspore/profiler/__init__.py +28 -0
  1178. mindspore/profiler/common/__init__.py +14 -0
  1179. mindspore/profiler/common/constant.py +29 -0
  1180. mindspore/profiler/common/exceptions/__init__.py +14 -0
  1181. mindspore/profiler/common/exceptions/error_code.py +83 -0
  1182. mindspore/profiler/common/exceptions/exceptions.py +286 -0
  1183. mindspore/profiler/common/process_pool.py +41 -0
  1184. mindspore/profiler/common/registry.py +47 -0
  1185. mindspore/profiler/common/singleton.py +28 -0
  1186. mindspore/profiler/common/struct_type.py +118 -0
  1187. mindspore/profiler/common/util.py +472 -0
  1188. mindspore/profiler/common/validator/__init__.py +14 -0
  1189. mindspore/profiler/common/validator/validate_path.py +84 -0
  1190. mindspore/profiler/dynamic_profiler.py +694 -0
  1191. mindspore/profiler/envprofiling.py +254 -0
  1192. mindspore/profiler/parser/__init__.py +14 -0
  1193. mindspore/profiler/parser/aicpu_data_parser.py +272 -0
  1194. mindspore/profiler/parser/ascend_analysis/__init__.py +14 -0
  1195. mindspore/profiler/parser/ascend_analysis/constant.py +71 -0
  1196. mindspore/profiler/parser/ascend_analysis/file_manager.py +180 -0
  1197. mindspore/profiler/parser/ascend_analysis/function_event.py +185 -0
  1198. mindspore/profiler/parser/ascend_analysis/fwk_cann_parser.py +136 -0
  1199. mindspore/profiler/parser/ascend_analysis/fwk_file_parser.py +131 -0
  1200. mindspore/profiler/parser/ascend_analysis/msprof_timeline_parser.py +104 -0
  1201. mindspore/profiler/parser/ascend_analysis/path_manager.py +313 -0
  1202. mindspore/profiler/parser/ascend_analysis/profiler_info_parser.py +123 -0
  1203. mindspore/profiler/parser/ascend_analysis/tlv_decoder.py +86 -0
  1204. mindspore/profiler/parser/ascend_analysis/trace_event_manager.py +75 -0
  1205. mindspore/profiler/parser/ascend_cluster_generator.py +116 -0
  1206. mindspore/profiler/parser/ascend_communicate_generator.py +314 -0
  1207. mindspore/profiler/parser/ascend_flops_generator.py +116 -0
  1208. mindspore/profiler/parser/ascend_fpbp_generator.py +82 -0
  1209. mindspore/profiler/parser/ascend_hccl_generator.py +271 -0
  1210. mindspore/profiler/parser/ascend_integrate_generator.py +42 -0
  1211. mindspore/profiler/parser/ascend_memory_generator.py +185 -0
  1212. mindspore/profiler/parser/ascend_msprof_exporter.py +282 -0
  1213. mindspore/profiler/parser/ascend_msprof_generator.py +187 -0
  1214. mindspore/profiler/parser/ascend_op_generator.py +334 -0
  1215. mindspore/profiler/parser/ascend_steptrace_generator.py +94 -0
  1216. mindspore/profiler/parser/ascend_timeline_generator.py +545 -0
  1217. mindspore/profiler/parser/base_timeline_generator.py +483 -0
  1218. mindspore/profiler/parser/container.py +229 -0
  1219. mindspore/profiler/parser/cpu_gpu_timeline_generator.py +697 -0
  1220. mindspore/profiler/parser/flops_parser.py +531 -0
  1221. mindspore/profiler/parser/framework_enum.py +111 -0
  1222. mindspore/profiler/parser/framework_parser.py +464 -0
  1223. mindspore/profiler/parser/framework_struct.py +61 -0
  1224. mindspore/profiler/parser/gpu_analysis/__init__.py +14 -0
  1225. mindspore/profiler/parser/gpu_analysis/function_event.py +44 -0
  1226. mindspore/profiler/parser/gpu_analysis/fwk_file_parser.py +89 -0
  1227. mindspore/profiler/parser/gpu_analysis/profiler_info_parser.py +72 -0
  1228. mindspore/profiler/parser/hccl_parser.py +573 -0
  1229. mindspore/profiler/parser/hwts_log_parser.py +122 -0
  1230. mindspore/profiler/parser/integrator.py +526 -0
  1231. mindspore/profiler/parser/memory_usage_parser.py +277 -0
  1232. mindspore/profiler/parser/minddata_analyzer.py +800 -0
  1233. mindspore/profiler/parser/minddata_parser.py +186 -0
  1234. mindspore/profiler/parser/minddata_pipeline_parser.py +299 -0
  1235. mindspore/profiler/parser/op_intermediate_parser.py +149 -0
  1236. mindspore/profiler/parser/optime_parser.py +250 -0
  1237. mindspore/profiler/parser/profiler_info.py +213 -0
  1238. mindspore/profiler/parser/step_trace_parser.py +666 -0
  1239. mindspore/profiler/profiler.py +153 -0
  1240. mindspore/profiler/profiling.py +1922 -0
  1241. mindspore/rewrite/__init__.py +28 -0
  1242. mindspore/rewrite/api/__init__.py +17 -0
  1243. mindspore/rewrite/api/node.py +519 -0
  1244. mindspore/rewrite/api/node_type.py +53 -0
  1245. mindspore/rewrite/api/pattern_engine.py +490 -0
  1246. mindspore/rewrite/api/scoped_value.py +181 -0
  1247. mindspore/rewrite/api/symbol_tree.py +497 -0
  1248. mindspore/rewrite/ast_helpers/__init__.py +25 -0
  1249. mindspore/rewrite/ast_helpers/ast_converter.py +143 -0
  1250. mindspore/rewrite/ast_helpers/ast_finder.py +404 -0
  1251. mindspore/rewrite/ast_helpers/ast_flattener.py +268 -0
  1252. mindspore/rewrite/ast_helpers/ast_modifier.py +605 -0
  1253. mindspore/rewrite/ast_helpers/ast_replacer.py +79 -0
  1254. mindspore/rewrite/common/__init__.py +19 -0
  1255. mindspore/rewrite/common/config.py +24 -0
  1256. mindspore/rewrite/common/error_log.py +39 -0
  1257. mindspore/rewrite/common/event.py +28 -0
  1258. mindspore/rewrite/common/namer.py +271 -0
  1259. mindspore/rewrite/common/namespace.py +118 -0
  1260. mindspore/rewrite/common/observable.py +44 -0
  1261. mindspore/rewrite/common/observer.py +54 -0
  1262. mindspore/rewrite/node/__init__.py +22 -0
  1263. mindspore/rewrite/node/call_function.py +95 -0
  1264. mindspore/rewrite/node/cell_container.py +139 -0
  1265. mindspore/rewrite/node/control_flow.py +113 -0
  1266. mindspore/rewrite/node/node.py +1428 -0
  1267. mindspore/rewrite/node/node_manager.py +283 -0
  1268. mindspore/rewrite/node/node_topological_manager.py +223 -0
  1269. mindspore/rewrite/parsers/__init__.py +29 -0
  1270. mindspore/rewrite/parsers/arguments_parser.py +63 -0
  1271. mindspore/rewrite/parsers/assign_parser.py +852 -0
  1272. mindspore/rewrite/parsers/attribute_parser.py +57 -0
  1273. mindspore/rewrite/parsers/class_def_parser.py +289 -0
  1274. mindspore/rewrite/parsers/constant_parser.py +104 -0
  1275. mindspore/rewrite/parsers/container_parser.py +88 -0
  1276. mindspore/rewrite/parsers/expr_parser.py +55 -0
  1277. mindspore/rewrite/parsers/for_parser.py +61 -0
  1278. mindspore/rewrite/parsers/function_def_parser.py +84 -0
  1279. mindspore/rewrite/parsers/if_parser.py +85 -0
  1280. mindspore/rewrite/parsers/module_parser.py +117 -0
  1281. mindspore/rewrite/parsers/parser.py +43 -0
  1282. mindspore/rewrite/parsers/parser_register.py +86 -0
  1283. mindspore/rewrite/parsers/return_parser.py +37 -0
  1284. mindspore/rewrite/parsers/while_parser.py +59 -0
  1285. mindspore/rewrite/sparsify/__init__.py +0 -0
  1286. mindspore/rewrite/sparsify/sparse_transformer.py +457 -0
  1287. mindspore/rewrite/sparsify/sparsify.py +112 -0
  1288. mindspore/rewrite/sparsify/utils.py +179 -0
  1289. mindspore/rewrite/symbol_tree/__init__.py +20 -0
  1290. mindspore/rewrite/symbol_tree/symbol_tree.py +1819 -0
  1291. mindspore/rewrite/symbol_tree/symbol_tree_builder.py +76 -0
  1292. mindspore/rewrite/symbol_tree/symbol_tree_dumper.py +142 -0
  1293. mindspore/run_check/__init__.py +20 -0
  1294. mindspore/run_check/_check_version.py +507 -0
  1295. mindspore/run_check/run_check.py +66 -0
  1296. mindspore/safeguard/__init__.py +18 -0
  1297. mindspore/safeguard/rewrite_obfuscation.py +875 -0
  1298. mindspore/scipy/__init__.py +18 -0
  1299. mindspore/scipy/fft.py +264 -0
  1300. mindspore/scipy/linalg.py +919 -0
  1301. mindspore/scipy/ops.py +165 -0
  1302. mindspore/scipy/ops_grad.py +115 -0
  1303. mindspore/scipy/ops_wrapper.py +74 -0
  1304. mindspore/scipy/optimize/__init__.py +20 -0
  1305. mindspore/scipy/optimize/_bfgs.py +230 -0
  1306. mindspore/scipy/optimize/_lagrange.py +201 -0
  1307. mindspore/scipy/optimize/_lbfgs.py +146 -0
  1308. mindspore/scipy/optimize/gradient_optimization_algorithm.py +168 -0
  1309. mindspore/scipy/optimize/line_search.py +370 -0
  1310. mindspore/scipy/optimize/linear_sum_assignment.py +78 -0
  1311. mindspore/scipy/optimize/minimize.py +200 -0
  1312. mindspore/scipy/utils.py +156 -0
  1313. mindspore/scipy/utils_const.py +246 -0
  1314. mindspore/train/__init__.py +48 -0
  1315. mindspore/train/_utils.py +465 -0
  1316. mindspore/train/amp.py +935 -0
  1317. mindspore/train/anf_ir_pb2.py +1517 -0
  1318. mindspore/train/callback/__init__.py +44 -0
  1319. mindspore/train/callback/_backup_and_restore.py +117 -0
  1320. mindspore/train/callback/_callback.py +613 -0
  1321. mindspore/train/callback/_checkpoint.py +814 -0
  1322. mindspore/train/callback/_cluster_monitor.py +201 -0
  1323. mindspore/train/callback/_dataset_graph.py +150 -0
  1324. mindspore/train/callback/_early_stop.py +239 -0
  1325. mindspore/train/callback/_flops_collector.py +239 -0
  1326. mindspore/train/callback/_history.py +92 -0
  1327. mindspore/train/callback/_lambda_callback.py +80 -0
  1328. mindspore/train/callback/_landscape.py +1049 -0
  1329. mindspore/train/callback/_loss_monitor.py +107 -0
  1330. mindspore/train/callback/_lr_scheduler_callback.py +76 -0
  1331. mindspore/train/callback/_on_request_exit.py +298 -0
  1332. mindspore/train/callback/_reduce_lr_on_plateau.py +226 -0
  1333. mindspore/train/callback/_summary_collector.py +1184 -0
  1334. mindspore/train/callback/_tft_register.py +352 -0
  1335. mindspore/train/callback/_time_monitor.py +141 -0
  1336. mindspore/train/checkpoint_pb2.py +233 -0
  1337. mindspore/train/data_sink.py +219 -0
  1338. mindspore/train/dataset_helper.py +692 -0
  1339. mindspore/train/lineage_pb2.py +1260 -0
  1340. mindspore/train/loss_scale_manager.py +213 -0
  1341. mindspore/train/memory_profiling_pb2.py +298 -0
  1342. mindspore/train/metrics/__init__.py +175 -0
  1343. mindspore/train/metrics/accuracy.py +133 -0
  1344. mindspore/train/metrics/auc.py +129 -0
  1345. mindspore/train/metrics/bleu_score.py +170 -0
  1346. mindspore/train/metrics/confusion_matrix.py +700 -0
  1347. mindspore/train/metrics/cosine_similarity.py +109 -0
  1348. mindspore/train/metrics/dice.py +116 -0
  1349. mindspore/train/metrics/error.py +175 -0
  1350. mindspore/train/metrics/fbeta.py +167 -0
  1351. mindspore/train/metrics/hausdorff_distance.py +333 -0
  1352. mindspore/train/metrics/loss.py +97 -0
  1353. mindspore/train/metrics/mean_surface_distance.py +189 -0
  1354. mindspore/train/metrics/metric.py +373 -0
  1355. mindspore/train/metrics/occlusion_sensitivity.py +225 -0
  1356. mindspore/train/metrics/perplexity.py +133 -0
  1357. mindspore/train/metrics/precision.py +160 -0
  1358. mindspore/train/metrics/recall.py +159 -0
  1359. mindspore/train/metrics/roc.py +223 -0
  1360. mindspore/train/metrics/root_mean_square_surface_distance.py +191 -0
  1361. mindspore/train/metrics/topk.py +167 -0
  1362. mindspore/train/mind_ir_pb2.py +1908 -0
  1363. mindspore/train/model.py +2252 -0
  1364. mindspore/train/node_strategy_pb2.py +653 -0
  1365. mindspore/train/print_pb2.py +184 -0
  1366. mindspore/train/profiling_parallel_pb2.py +151 -0
  1367. mindspore/train/serialization.py +3325 -0
  1368. mindspore/train/summary/__init__.py +23 -0
  1369. mindspore/train/summary/_lineage_adapter.py +41 -0
  1370. mindspore/train/summary/_summary_adapter.py +496 -0
  1371. mindspore/train/summary/_writer_pool.py +207 -0
  1372. mindspore/train/summary/enums.py +56 -0
  1373. mindspore/train/summary/summary_record.py +581 -0
  1374. mindspore/train/summary/writer.py +167 -0
  1375. mindspore/train/summary_pb2.py +1165 -0
  1376. mindspore/train/train_thor/__init__.py +20 -0
  1377. mindspore/train/train_thor/convert_utils.py +268 -0
  1378. mindspore/train/train_thor/dataset_helper.py +192 -0
  1379. mindspore/train/train_thor/model_thor.py +257 -0
  1380. mindspore/utils/__init__.py +21 -0
  1381. mindspore/utils/utils.py +60 -0
  1382. mindspore/version.py +1 -0
  1383. mindspore-2.4.0.dist-info/METADATA +352 -0
  1384. mindspore-2.4.0.dist-info/RECORD +1387 -0
  1385. mindspore-2.4.0.dist-info/WHEEL +5 -0
  1386. mindspore-2.4.0.dist-info/entry_points.txt +3 -0
  1387. mindspore-2.4.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1660 @@
1
+ # Copyright 2019 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Built-in py_transforms_utils functions.
16
+ """
17
+ import colorsys
18
+ import io
19
+ import math
20
+ import numbers
21
+ import random
22
+
23
+ import numpy as np
24
+ from PIL import Image, ImageOps, ImageEnhance
25
+
26
+ from ..core.py_util_helpers import is_numpy
27
+ from .utils import Inter, FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM, PERSPECTIVE, AFFINE
28
+
29
# Message template for augmentations that require a PIL image; ``{}`` receives the type actually passed in.
augment_error_message = ("img should be PIL image. Got {}. "
                         "Use Decode() for encoded data or ToPIL() for decoded data.")
30
+
31
+
32
def is_pil(img):
    """
    Report whether an object is a PIL image.

    Args:
        img: Any object.

    Returns:
        bool, True if `img` is an instance of PIL.Image.Image (subclasses included), otherwise False.
    """
    expected_type = Image.Image
    return isinstance(img, expected_type)
43
+
44
+
45
def normalize(img, mean, std, pad_channel=False, dtype="float32"):
    """
    Standardize a CHW image channel by channel, computing ``(img - mean) / std``.

    Note that the result is not clamped to any range; it is simply shifted and scaled per channel.

    Args:
        img (numpy.ndarray): Non-integer image of shape (C, H, W).
        mean (list): Per-channel means in channel order. A single value is reused for every channel.
        std (list): Per-channel standard deviations in channel order. A single value is reused for
            every channel.
        pad_channel (bool): If True, append one extra all-zero channel to the result. Default: ``False``.
        dtype (str): Output datatype. Only consulted when `pad_channel` is ``True``: ``"float16"`` casts
            the padded result to float16. Default: ``"float32"``.

    Returns:
        numpy.ndarray, the normalized image.

    Raises:
        TypeError: If `img` is not recognized as a NumPy array, or is not 3-dimensional.
        NotImplementedError: If `img` has an integer dtype.
        ValueError: If `mean` and `std` differ in length, or their length is neither 1 nor C.
    """
    if not is_numpy(img):
        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))
    if img.ndim != 3:
        raise TypeError('img dimension should be 3. Got {}.'.format(img.ndim))
    if np.issubdtype(img.dtype, np.integer):
        raise NotImplementedError("Unsupported image datatype: [{}], pls execute [ToTensor] before [Normalize]."
                                  .format(img.dtype))

    channels = img.shape[0]  # layout is (C, H, W)
    if len(mean) != len(std):
        raise ValueError("Length of mean and std must be equal.")
    if len(mean) == 1:
        # Broadcast a single statistic to every channel.
        mean, std = [mean[0]] * channels, [std[0]] * channels
    elif len(mean) != channels:
        raise ValueError("Length of mean and std must both be 1 or equal to the number of channels({0})."
                         .format(channels))

    mean_arr = np.array(mean, dtype=img.dtype)
    std_arr = np.array(std, dtype=img.dtype)
    # Add two trailing axes so the per-channel statistics broadcast over H and W.
    result = (img - mean_arr[:, None, None]) / std_arr[:, None, None]

    if not pad_channel:
        return result
    pad = np.zeros([1, result.shape[1], result.shape[2]], dtype=np.float32)
    result = np.concatenate((result, pad), axis=0)
    if dtype == "float16":
        result = result.astype(np.float16)
    return result
93
+
94
+
95
def decode(img):
    """
    Decode an encoded image buffer into a PIL Image in RGB mode.

    Args:
        img (bytes-like): Raw bytes of an encoded image file.

    Returns:
        PIL.Image.Image, the decoded image converted to RGB mode.

    Raises:
        ValueError: If the data cannot be decoded (an OSError is raised while opening or converting it),
            or if an AttributeError is raised while decoding. The original exception is chained as
            ``__cause__`` so its traceback is preserved.
    """
    # NOTE(review): io.BytesIO raises TypeError (not AttributeError) for inputs that are not
    # bytes-like, e.g. an already-decoded PIL image; that TypeError is not caught below and
    # propagates unchanged.
    try:
        data = io.BytesIO(img)
        decoded = Image.open(data)
        # Image.open is lazy (it only identifies the file); convert() forces the actual decode,
        # so it must stay inside the try block for decode failures to surface as ValueError.
        return decoded.convert('RGB')
    except OSError as e:  # IOError is an alias of OSError in Python 3
        raise ValueError("{0}\n: Failed to decode given image.".format(e)) from e
    except AttributeError as e:
        raise ValueError("{0}\n: Failed to decode, Image might already be decoded.".format(e)) from e
114
+
115
+
116
def hwc_to_chw(img):
    """
    Move the channel axis of an (H, W, C) array to the front, giving (C, H, W).

    A 2-D (H, W) array has no channel axis and is returned as-is (the same object, not a copy).

    Args:
        img (numpy.ndarray): Array of rank 2 or 3.

    Returns:
        numpy.ndarray, a C-ordered copy in (C, H, W) layout for rank-3 input, or `img` itself for
        rank-2 input.

    Raises:
        TypeError: If `img` is not recognized as a NumPy array, or its rank is not 2 or 3.
    """
    if not is_numpy(img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(img)))
    rank = img.ndim
    if rank == 2:
        return img
    if rank != 3:
        raise TypeError("img dimension should be 2 or 3. Got {}.".format(rank))
    chw_view = img.transpose((2, 0, 1))
    # transpose() returns a strided view; copy() materializes it as a new array.
    return chw_view.copy()
134
+
135
+
136
def to_tensor(img, output_type):
    """
    Turn a PIL image or HW/HWC NumPy array into a CHW NumPy array scaled by 1/255.

    A 2-D input gains a single channel axis, so the result has shape (1, H, W). The scaled array is
    passed to ``to_type`` together with `output_type`.

    Args:
        img (Union[PIL.Image.Image, numpy.ndarray]): Image to convert.
        output_type: Target NumPy datatype, e.g. ``np.float32``.

    Returns:
        numpy.ndarray, the converted image in (C, H, W) layout.

    Raises:
        TypeError: If `img` is neither a PIL image nor a NumPy array, or its rank is not 2 or 3.
    """
    if not is_pil(img) and not is_numpy(img):
        raise TypeError("The input image should be of type numpy.ndarray or PIL.Image.Image. Got {}.".format(type(img)))

    arr = np.asarray(img)
    if arr.ndim == 2:
        # Grayscale: add a trailing channel axis so the array is HWC.
        arr = arr[:, :, None]
    elif arr.ndim != 3:
        raise TypeError("The dimension of input image should be 2 or 3. Got {}.".format(arr.ndim))

    chw = hwc_to_chw(arr)
    scaled = chw / 255.
    return to_type(scaled, output_type)
161
+
162
+
163
def to_pil(img):
    """
    Convert a NumPy array to a PIL image; PIL images are passed through unchanged.

    Accepted arrays are (H, W) or (H, W, C) with C <= 4. A single-channel array may use one of several
    bool/int/uint/float dtypes and is squeezed to (H, W) before conversion. Multi-channel arrays
    must be uint8.

    Args:
        img (Union[PIL.Image.Image, numpy.ndarray]): Image to convert.

    Returns:
        PIL.Image.Image, the converted image (or `img` itself if it is already a PIL image).

    Raises:
        TypeError: If `img` is neither a PIL image nor a NumPy array, or its dtype is not supported
            for its channel count.
        ValueError: If the array rank is not 2 or 3, or it has more than 4 channels.
    """
    if is_pil(img):
        return img
    if not isinstance(img, np.ndarray):
        raise TypeError("The input image should be of type numpy.ndarray or PIL.Image.Image. "
                        "Got {}.".format(type(img)))
    if img.ndim not in (2, 3):
        raise ValueError("The dimension of input image should be 2 or 3. Got {}.".format(img.ndim))
    if img.ndim == 2:
        img = np.expand_dims(img, 2)

    channels = img.shape[-1]
    if channels > 4:
        raise ValueError("The channel of input image should not exceed 4. Got {}.".format(channels))

    if channels == 1:
        single_channel_dtypes = (np.bool_, np.int8, np.int16, np.int32, np.uint8, np.uint16, np.uint32,
                                 np.float32, np.float64)
        if img.dtype not in single_channel_dtypes:
            raise TypeError("The input image type {} is not supported when "
                            "image shape is [H, W] or [H, W, 1].".format(img.dtype))
        img = img[:, :, 0]
    elif img.dtype != np.uint8:
        raise TypeError("The input image type {} is not supported when "
                        "image shape is [H, W, 2], [H, W, 3] or [H, W, 4].".format(img.dtype))
    return Image.fromarray(img)
194
+
195
+
196
def horizontal_flip(img):
    """
    Mirror a PIL image left-to-right (transpose with ``FLIP_LEFT_RIGHT``).

    Args:
        img (PIL.Image.Image): Image to flip.

    Returns:
        PIL.Image.Image, the horizontally flipped image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return img.transpose(FLIP_LEFT_RIGHT)
    raise TypeError(augment_error_message.format(type(img)))
210
+
211
+
212
def vertical_flip(img):
    """
    Mirror a PIL image top-to-bottom (transpose with ``FLIP_TOP_BOTTOM``).

    Args:
        img (PIL.Image.Image): Image to flip.

    Returns:
        PIL.Image.Image, the vertically flipped image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return img.transpose(FLIP_TOP_BOTTOM)
    raise TypeError(augment_error_message.format(type(img)))
226
+
227
+
228
def random_horizontal_flip(img, prob):
    """
    Flip a PIL image left-to-right with probability `prob`.

    Args:
        img (PIL.Image.Image): Image that may be flipped.
        prob (float): Flip probability. The image is flipped when `prob` exceeds one draw
            from ``random.random()``.

    Returns:
        PIL.Image.Image, the flipped image, or `img` unchanged.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    draw = random.random()
    if draw < prob:
        return horizontal_flip(img)
    return img
246
+
247
+
248
def random_vertical_flip(img, prob):
    """
    Flip a PIL image top-to-bottom with probability `prob`.

    Args:
        img (PIL.Image.Image): Image that may be flipped.
        prob (float): Flip probability. The image is flipped when `prob` exceeds one draw
            from ``random.random()``.

    Returns:
        PIL.Image.Image, the flipped image, or `img` unchanged.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    draw = random.random()
    if draw < prob:
        return vertical_flip(img)
    return img
266
+
267
+
268
def crop(img, top, left, height, width):
    """
    Cut a rectangular region out of a PIL image.

    Coordinates are measured from the image's top-left corner (0, 0).

    Args:
        img (PIL.Image.Image): Image to crop.
        top (int): Row of the box's top edge.
        left (int): Column of the box's left edge.
        height (int): Box height in pixels.
        width (int): Box width in pixels.

    Returns:
        PIL.Image.Image, the cropped region.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    right = left + width
    bottom = top + height
    # PIL's crop box is (left, upper, right, lower).
    box = (left, top, right, bottom)
    return img.crop(box)
287
+
288
+
289
def resize(img, size, interpolation=Inter.BILINEAR):
    """
    Resize the input PIL Image to desired size.

    Args:
        img (PIL.Image.Image): Image to be resized.
        size (Union[int, sequence]): The output size of the resized image.
            If size is an integer, the shorter edge of the image is resized to this value and the other
            edge is scaled to keep the aspect ratio (the scaled length is truncated to an int).
            If size is a sequence of (height, width), this will be the desired output size.
        interpolation (interpolation mode): Image interpolation mode, passed to PIL's ``Image.resize`` as
            the resampling filter. Default: ``Inter.BILINEAR``.

    Returns:
        PIL.Image.Image, resized image. If `size` is an integer equal to the current shorter edge, the
        input image itself is returned.

    Raises:
        TypeError: If `img` is not a PIL image, or `size` is neither an int nor a list/tuple of length 2.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
        # Adjacent literals are concatenated verbatim, so the separating space must be kept explicitly.
        raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2. '
                        'Got {}.'.format(size))

    if isinstance(size, int):
        img_width, img_height = img.size
        aspect_ratio = img_width / img_height  # maintain the aspect ratio
        if min(img_width, img_height) == size:
            # The shorter edge already has the requested length; nothing to do.
            return img
        if img_width < img_height:
            out_width = size
            out_height = int(size / aspect_ratio)
        else:
            out_height = size
            out_width = int(size * aspect_ratio)
        return img.resize((out_width, out_height), interpolation)
    # PIL expects (width, height) while `size` is (height, width).
    return img.resize(size[::-1], interpolation)
324
+
325
+
326
def center_crop(img, size):
    """
    Cut a box of the requested size from the middle of a PIL image.

    Args:
        img (PIL.Image.Image): Image to crop.
        size (Union[int, tuple]): Box size. An integer gives a square (size, size) box; a length-2
            sequence is interpreted as (height, width).

    Returns:
        PIL.Image.Image, the centered crop.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    crop_height, crop_width = (size, size) if isinstance(size, int) else size
    img_width, img_height = img.size
    # Split the leftover margin evenly on both sides, rounding to the nearest pixel.
    offset_top = int(round((img_height - crop_height) / 2.))
    offset_left = int(round((img_width - crop_width) / 2.))
    return crop(img, offset_top, offset_left, crop_height, crop_width)
349
+
350
+
351
def random_resize_crop(img, size, scale, ratio, interpolation=Inter.BILINEAR, max_attempts=10):
    """
    Crop the input PIL Image to a random size and aspect ratio, then resize it.

    Args:
        img (PIL.Image.Image): Image to be randomly cropped and resized.
        size (Union[int, sequence]): The size of the output image.
            If size is an integer, a square output of size (size, size) is produced.
            If size is a sequence of length 2, it should be (height, width).
        scale (tuple): Range (min, max) of the crop area, as a fraction of the original image area.
        ratio (tuple): Range (min, max) of the crop's aspect ratio (width / height). The ratio is
            sampled uniformly on a logarithmic scale.
        interpolation (interpolation mode): Image interpolation mode. Default: ``Inter.BILINEAR``.
        max_attempts (int): The maximum number of attempts to propose a valid crop area. Default: ``10``.
            If exceeded, fall back to a centered crop whose aspect ratio is clamped to `ratio`.

    Returns:
        PIL.Image.Image, randomly cropped and resized image.

    Raises:
        TypeError: If `img` is not a PIL image, or `size` is neither an int nor a list/tuple of length 2.
        ValueError: If `scale` or `ratio` is not ordered as (min, max).
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif not (isinstance(size, (tuple, list)) and len(size) == 2):
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")

    if scale[0] > scale[1] or ratio[0] > ratio[1]:
        raise ValueError("Range should be in the order of (min, max).")

    def _sample_crop_box(image):
        """Return (top, left, height, width) of a random valid crop box, or of a centered fallback box."""
        img_width, img_height = image.size
        img_area = img_width * img_height
        # The log-space bounds of the aspect ratio do not change between attempts; compute them once.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(max_attempts):
            crop_area = random.uniform(scale[0], scale[1]) * img_area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            width = int(round(math.sqrt(crop_area * aspect_ratio)))
            height = int(round(width / aspect_ratio))

            if 0 < width <= img_width and 0 < height <= img_height:
                top = random.randint(0, img_height - height)
                left = random.randint(0, img_width - width)
                return top, left, height, width

        # max_attempts exhausted: use a centered box whose aspect ratio is clamped into `ratio`.
        img_ratio = img_width / img_height
        if img_ratio < ratio[0]:
            width = img_width
            height = int(round(width / ratio[0]))
        elif img_ratio > ratio[1]:
            height = img_height
            width = int(round(height * ratio[1]))
        else:
            width = img_width
            height = img_height
        top = int(round((img_height - height) / 2.))
        left = int(round((img_width - width) / 2.))
        return top, left, height, width

    top, left, height, width = _sample_crop_box(img)
    img = crop(img, top, left, height, width)
    return resize(img, size, interpolation)
419
+
420
+
421
def random_crop(img, size, padding, pad_if_needed, fill_value, padding_mode):
    """
    Crop a PIL Image at a randomly chosen position, optionally padding it first.

    Args:
        img (PIL.Image.Image): Image to crop.
        size (Union[int, sequence]): Output size. An int ``s`` requests an ``(s, s)`` square crop;
            a sequence of length 2 is read as ``(height, width)``.
        padding (Union[int, sequence], optional): Border added before cropping, or ``None`` for none.
            One number pads every side; two values pad (left, right) with the first and
            (top, bottom) with the second; four values pad left, top, right, bottom.
        pad_if_needed (bool): If ``True``, pad the image when it is narrower or shorter than `size`.
        fill_value (Union[int, tuple]): Border colour for ``'constant'`` padding; a 3-tuple gives
            R, G, B. A list (as produced by JSON deserialization) is converted to a tuple.
        padding_mode (str): One of ``'constant'``, ``'edge'``, ``'reflect'``, ``'symmetric'``.

    Returns:
        PIL.Image.Image, the cropped image.

    Raises:
        TypeError: If `img` is not a PIL image or `size` has an unsupported type/length.
        ValueError: If the (possibly padded) image is smaller than the requested crop.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if isinstance(size, int):
        size = (size, size)
    elif not (isinstance(size, (tuple, list)) and len(size) == 2):
        raise TypeError("Size should be a single integer or a list/tuple (h, w) of length 2.")

    if isinstance(fill_value, list):
        # JSON round-trips turn tuples into lists; PIL expects a tuple colour.
        fill_value = tuple(fill_value)

    def _pick_window(image, crop_size):
        # Choose (top, left, height, width) of a crop window that fits inside `image`.
        image_w, image_h = image.size
        crop_h, crop_w = crop_size
        if crop_h > image_h or crop_w > image_w:
            raise ValueError("Crop size {} is larger than input image size {}.".format(crop_size, (image_h, image_w)))
        if (crop_w, crop_h) == (image_w, image_h):
            return 0, 0, image_h, image_w
        # top is drawn before left
        return random.randint(0, image_h - crop_h), random.randint(0, image_w - crop_w), crop_h, crop_w

    if padding is not None:
        img = pad(img, padding, fill_value, padding_mode)

    if pad_if_needed:
        # PIL reports size as (width, height) whereas `size` is (height, width).
        missing_w = size[1] - img.size[0]
        if missing_w > 0:
            img = pad(img, (missing_w, 0), fill_value, padding_mode)
        missing_h = size[0] - img.size[1]
        if missing_h > 0:
            img = pad(img, (0, missing_h), fill_value, padding_mode)

    top, left, height, width = _pick_window(img, size)
    return crop(img, top, left, height, width)
492
+
493
+
494
def adjust_brightness(img, brightness_factor):
    """
    Scale the brightness of a PIL image with ``ImageEnhance.Brightness``.

    Args:
        img (PIL.Image.Image): Image whose brightness is changed.
        brightness_factor (float): Non-negative enhancement factor; ``0`` yields a black image
            and ``1`` leaves the image unchanged.

    Returns:
        PIL.Image.Image, the brightness-adjusted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageEnhance.Brightness(img).enhance(brightness_factor)
    raise TypeError(augment_error_message.format(type(img)))
512
+
513
+
514
def adjust_contrast(img, contrast_factor):
    """
    Scale the contrast of a PIL image with ``ImageEnhance.Contrast``.

    Args:
        img (PIL.Image.Image): Image whose contrast is changed.
        contrast_factor (float): Non-negative enhancement factor; ``0`` yields a solid gray image
            and ``1`` leaves the image unchanged.

    Returns:
        PIL.Image.Image, the contrast-adjusted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageEnhance.Contrast(img).enhance(contrast_factor)
    raise TypeError(augment_error_message.format(type(img)))
532
+
533
+
534
def adjust_saturation(img, saturation_factor):
    """
    Scale the colour saturation of a PIL image with ``ImageEnhance.Color``.

    Args:
        img (PIL.Image.Image): Image whose saturation is changed.
        saturation_factor (float): Non-negative enhancement factor; ``0`` yields a black-and-white
            image and ``1`` leaves the image unchanged.

    Returns:
        PIL.Image.Image, the saturation-adjusted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageEnhance.Color(img).enhance(saturation_factor)
    raise TypeError(augment_error_message.format(type(img)))
553
+
554
+
555
def _hue_shift(hue_factor):
    """
    Convert a hue factor into the uint8 offset added to the HSV hue channel.

    The product ``hue_factor * 255`` is truncated toward zero and wrapped modulo 256 using
    Python integers. Casting a negative float straight to ``np.uint8`` (e.g. ``np.uint8(-63.75)``)
    is an out-of-range float-to-unsigned conversion whose result is not portable across
    platforms / NumPy versions; the integer path yields the wrap-around value consistently.

    Args:
        hue_factor (float): Hue shift, expected in [-0.5, 0.5].

    Returns:
        numpy.uint8, offset in [0, 255].
    """
    return np.uint8(int(hue_factor * 255) % 256)


def adjust_hue(img, hue_factor):
    """
    Adjust hue of an image. The Hue is changed by changing the HSV values after image is converted to HSV.

    Args:
        img (PIL.Image.Image): PIL Image to be adjusted.
        hue_factor (float): Amount to shift the Hue channel. Value should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel. This
            is because Hue wraps around when rotated 360 degrees.
            ``0`` means no shift that gives the original image while both -0.5 and 0.5
            will give an image with complementary colors.

    Returns:
        PIL.Image.Image, hue adjusted image. Images in mode 'L', '1', 'I' or 'F' are returned unchanged.

    Raises:
        ValueError: If `hue_factor` is outside [-0.5, 0.5].
        TypeError: If `img` is not a PIL image.
    """
    if not -0.5 <= hue_factor <= 0.5:
        raise ValueError('image_hue_factor {} is not in [-0.5, 0.5].'.format(hue_factor))

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    mode = img.mode
    # These modes are returned as-is (no HSV round trip).
    if mode in {'L', '1', 'I', 'F'}:
        return img

    hue, saturation, value = img.convert('HSV').split()
    np_hue = np.array(hue, dtype=np.uint8)

    # uint8 addition wraps modulo 256, which rotates the hue around the colour circle.
    with np.errstate(over='ignore'):
        np_hue += _hue_shift(hue_factor)
    hue = Image.fromarray(np_hue, 'L')

    return Image.merge('HSV', (hue, saturation, value)).convert(mode)
592
+
593
+
594
def to_type(img, output_type):
    """
    Convert the NumPy image array to desired NumPy dtype.

    Args:
        img (numpy.ndarray): NumPy image to cast to desired NumPy dtype.
        output_type (NumPy datatype): NumPy dtype to cast to.

    Returns:
        img (numpy.ndarray), Converted image (a new array produced by ``ndarray.astype``).

    Raises:
        TypeError: If `img` is not a NumPy array.
        RuntimeError: If the cast fails; the original NumPy error is attached as ``__cause__``.
    """
    if not is_numpy(img):
        raise TypeError("img should be NumPy image. Got {}.".format(type(img)))

    try:
        return img.astype(output_type)
    except Exception as err:
        # Chain the NumPy error so the concrete reason (e.g. an unknown dtype name) is not lost.
        raise RuntimeError("output_type: " + str(output_type) + " is not a valid datatype.") from err
612
+
613
+
614
def rotate(img, angle, resample, expand, center, fill_value):
    """
    Rotate a PIL image counter-clockwise by `angle` degrees.

    Wraps ``PIL.Image.Image.rotate`` after normalising `fill_value` and guarding the
    ``Inter.ANTIALIAS`` case. See
    https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate

    Args:
        img (PIL.Image.Image): Image to rotate.
        angle (int or float): Rotation in degrees, counter-clockwise; reduced modulo 360.
        resample: Resampling filter passed to PIL (e.g. ``Inter.NEAREST``, ``Inter.BILINEAR``,
            ``Inter.BICUBIC``). With ``Inter.ANTIALIAS`` only multiples of 90 degrees and no
            custom `center` are accepted.
        expand (bool): If ``True``, enlarge the output so the whole rotated image fits;
            otherwise keep the input size. Assumes rotation about the centre without translation.
        center (tuple, optional): (x, y) rotation centre, origin at the top-left corner.
        fill_value (Union[int, tuple, list]): Colour for the area outside the rotated image.
            An int is repeated for R, G and B; a list is converted to a tuple.

    Returns:
        PIL.Image.Image, the rotated image.

    Raises:
        TypeError: If `img` is not a PIL image.
        ValueError: If ``Inter.ANTIALIAS`` is used with a non-right angle or a truthy `center`.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(fill_value, int):
        # One intensity broadcast to all three RGB channels.
        fill_value = (fill_value,) * 3
    elif isinstance(fill_value, list):
        # JSON round-trips turn tuples into lists; PIL expects a tuple colour.
        fill_value = tuple(fill_value)

    # Normalise into [0, 360) so multiples of 90 can be recognised.
    angle %= 360.0
    if resample == Inter.ANTIALIAS and (angle not in (0, 90, 180, 270) or center):
        raise ValueError("When using Inter.ANTIALIAS, center needs to be None and "
                         "angle needs to be an integer multiple of 90.")
    return img.rotate(angle, resample, expand, center, fillcolor=fill_value)
653
+
654
+
655
def random_color_adjust(img, brightness, contrast, saturation, hue):
    """
    Randomly adjust the brightness, contrast, saturation, and hue of an image.

    One factor is sampled for each property, then the four adjustments are applied in a
    random order.

    Args:
        img (PIL.Image.Image): Image to have its color adjusted randomly.
        brightness (Union[float, tuple]): Brightness adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
            If it is a sequence, it should be [min, max] for the range.
        contrast (Union[float, tuple]): Contrast adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-contrast), 1+contrast].
            If it is a sequence, it should be [min, max] for the range.
        saturation (Union[float, tuple]): Saturation adjustment factor. Cannot be negative.
            If it is a float, the factor is uniformly chosen from the range [max(0, 1-saturation), 1+saturation].
            If it is a sequence, it should be [min, max] for the range.
        hue (Union[float, tuple]): Hue adjustment factor.
            If it is a float, the range will be [-hue, hue]. Value should be 0 <= hue <= 0.5.
            If it is a sequence, it should be [min, max] where -0.5 <= min <= max <= 0.5.

    Returns:
        PIL.Image.Image, image after random adjustment of its color.

    Raises:
        TypeError: If `img` is not a PIL image, or a factor is neither a number nor a
            list/tuple of length 2.
        ValueError: If a scalar factor is negative, or the resulting range is not within its
            bound (e.g. a scalar `hue` greater than 0.5).
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    def _input_to_factor(value, input_name, center=1, bound=(0, float('inf')), non_negative=True):
        # Turn `value` into a [low, high] range, validate it against `bound`, and sample from it.
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("The input value of {} cannot be negative.".format(input_name))
            # convert value into a range around `center`
            value = [center - value, center + value]
            if non_negative:
                value[0] = max(0, value[0])
        elif not (isinstance(value, (list, tuple)) and len(value) == 2):
            raise TypeError("Input of {} should be either a single value, or a list/tuple of "
                            "length 2.".format(input_name))
        # Ranges built from a scalar are validated too: a scalar hue above 0.5 would otherwise
        # yield a range wider than [-0.5, 0.5] and fail only intermittently inside adjust_hue,
        # whenever an out-of-range factor happened to be sampled.
        if not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("Please check your value range of {} is valid and "
                             "within the bound {}.".format(input_name, bound))
        return random.uniform(value[0], value[1])

    brightness_factor = _input_to_factor(brightness, 'brightness')
    contrast_factor = _input_to_factor(contrast, 'contrast')
    saturation_factor = _input_to_factor(saturation, 'saturation')
    hue_factor = _input_to_factor(hue, 'hue', center=0, bound=(-0.5, 0.5), non_negative=False)

    transforms = [lambda img: adjust_brightness(img, brightness_factor),
                  lambda img: adjust_contrast(img, contrast_factor),
                  lambda img: adjust_saturation(img, saturation_factor),
                  lambda img: adjust_hue(img, hue_factor)]

    # apply color adjustments in a random order
    random.shuffle(transforms)
    for transform in transforms:
        img = transform(img)

    return img
714
+
715
+
716
def random_lighting(img, alpha):
    """
    Add AlexNet-style PCA lighting noise to a PIL image.

    Three coefficients are drawn from a normal distribution N(0, alpha) (in R, G, B order).
    Each row of a fixed eigenvalue-weighted eigenvector table is combined with those
    coefficients to give one offset per channel. The offsets are added to every pixel and the
    result is clipped to [0, 255] and converted back to 8 bits.

    Args:
        img (PIL.Image.Image): Image to perturb; converted to RGB first if needed.
        alpha (float): Standard deviation of the sampled coefficients.

    Returns:
        PIL.Image.Image, RGB image with lighting noise added.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if img.mode != 'RGB':
        img = img.convert("RGB")

    # Sequential draws keep the same RNG consumption order (R, then G, then B).
    coeffs = [np.random.normal(loc=0.0, scale=alpha) for _ in range(3)]
    eig_table = np.array([
        [55.46 * -0.5675, 4.794 * 0.7192, 1.148 * 0.4009],
        [55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140],
        [55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203]
    ])
    offsets = [row[0] * coeffs[0] + row[1] * coeffs[1] + row[2] * coeffs[2] for row in eig_table]

    pixels = np.array(img).astype(np.float64)
    for channel, offset in enumerate(offsets):
        pixels[:, :, channel] += offset
    pixels = np.uint8(np.minimum(np.maximum(pixels, 0), 255))
    return Image.fromarray(pixels)
750
+
751
+
752
def random_rotation(img, degrees, resample, expand, center, fill_value):
    """
    Rotate a PIL image by an angle drawn uniformly from a range.

    See <https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.rotate>.

    Args:
        img (PIL.Image.Image): Image to rotate.
        degrees (Union[int, float, sequence]): A non-negative number ``d`` means the range
            ``(-d, d)``; a 2-element list/tuple is used as ``(min, max)``.
        resample: Resampling filter forwarded to `rotate` (e.g. ``Inter.NEAREST``,
            ``Inter.BILINEAR``, ``Inter.BICUBIC``).
        expand (bool): If ``True``, enlarge the output to hold the whole rotated image;
            otherwise keep the input size.
        center (tuple, optional): (x, y) rotation centre, origin at the top-left corner.
        fill_value (Union[int, tuple, list]): Colour for the area outside the rotated image;
            a list is converted to a tuple.

    Returns:
        PIL.Image.Image, the rotated image.

    Raises:
        TypeError: If `img` is not a PIL image or `degrees` is neither a number nor a list/tuple.
        ValueError: If a scalar `degrees` is negative or a sequence does not have length 2.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it cannot be negative.")
        low, high = -degrees, degrees
    elif isinstance(degrees, (list, tuple)):
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, the length must be 2.")
        low, high = degrees
    else:
        raise TypeError("Degrees must be a single non-negative number or a sequence.")

    if isinstance(fill_value, list):
        # JSON round-trips turn tuples into lists; PIL expects a tuple colour.
        fill_value = tuple(fill_value)

    return rotate(img, random.uniform(low, high), resample, expand, center, fill_value)
797
+
798
+
799
def five_crop(img, size):
    """
    Cut the four corner crops and the central crop out of a PIL image.

    Args:
        img (PIL.Image.Image): Image to crop.
        size (Union[int, sequence]): Crop size; an int ``s`` means ``(s, s)``, a 2-element
            list/tuple is ``(height, width)``.

    Returns:
        tuple, five PIL images ordered (top_left, top_right, bottom_left, bottom_right, center).

    Raises:
        TypeError: If `img` is not a PIL image or `size` has an unsupported type/length.
        ValueError: If the crop is larger than the image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(size, int):
        size = (size, size)
    elif not isinstance(size, (tuple, list)) or len(size) != 2:
        raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")

    # PIL.Image.Image.size is (width, height) while `size` is (height, width).
    img_width, img_height = img.size
    crop_height, crop_width = size
    if crop_height > img_height or crop_width > img_width:
        raise ValueError("Crop size {} is larger than input image size {}.".format(size, (img_height, img_width)))

    middle = center_crop(img, (crop_height, crop_width))
    right_x = img_width - crop_width
    bottom_y = img_height - crop_height
    # PIL crop boxes are (left, upper, right, lower).
    corner_boxes = (
        (0, 0, crop_width, crop_height),            # top-left
        (right_x, 0, img_width, crop_height),       # top-right
        (0, bottom_y, crop_width, img_height),      # bottom-left
        (right_x, bottom_y, img_width, img_height), # bottom-right
    )
    corners = tuple(img.crop(box) for box in corner_boxes)
    return corners + (middle,)
835
+
836
+
837
def ten_crop(img, size, use_vertical_flip=False):
    """
    Produce the five crops of an image plus the five crops of its flipped copy.

    Args:
        img (PIL.Image.Image): Image to crop.
        size (Union[int, sequence]): Crop size; an int ``s`` means ``(s, s)``, a 2-element
            list/tuple is ``(height, width)``.
        use_vertical_flip (bool): Flip vertically instead of horizontally for the second
            group of crops. Default: ``False``.

    Returns:
        tuple[PIL.Image.Image], ten images: (top_left, top_right, bottom_left, bottom_right, center)
        of the original image followed by the same five crops of the flipped image.

    Raises:
        TypeError: If `img` is not a PIL image or `size` has an unsupported type/length.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(size, int):
        size = (size, size)
    elif not isinstance(size, (tuple, list)) or len(size) != 2:
        raise TypeError("Size should be a single number or a list/tuple (h, w) of length 2.")

    flip = vertical_flip if use_vertical_flip else horizontal_flip
    # The left operand is evaluated first, so the original image is cropped before flipping.
    return five_crop(img, size) + five_crop(flip(img), size)
876
+
877
+
878
def grayscale(img, num_output_channels):
    """
    Convert a PIL image to grayscale with one or three channels.

    Args:
        img (PIL.Image.Image): Image to convert.
        num_output_channels (int): ``1`` for an 'L' image, ``3`` for an RGB image whose three
            channels all hold the same luminance plane.

    Returns:
        PIL.Image.Image, the grayscale image.

    Raises:
        TypeError: If `img` is not a PIL image.
        ValueError: If `num_output_channels` is neither 1 nor 3.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))
    if num_output_channels not in (1, 3):
        raise ValueError('num_output_channels should be either 1 or 3. Got {}.'.format(num_output_channels))

    gray = img.convert('L')
    if num_output_channels == 1:
        return gray

    # Replicate the single luminance plane into R, G and B.
    plane = np.array(gray, dtype=np.uint8)
    return Image.fromarray(np.dstack([plane, plane, plane]), 'RGB')
904
+
905
+
906
def pad(img, padding, fill_value, padding_mode):
    """
    Add a border around a PIL image.

    Args:
        img (PIL.Image.Image): Image to pad.
        padding (Union[int, sequence]): Border width. One number pads every side; two values
            pad (left, right) with the first and (top, bottom) with the second; four values pad
            left, top, right and bottom respectively.
        fill_value (Union[int, str, tuple, list]): Border colour for ``'constant'`` mode; a
            3-tuple gives R, G, B. A list is converted to a tuple. For 'L'/'1' images only the
            first element of a tuple is used.
        padding_mode (str): ``'constant'`` (uses ``ImageOps.expand``) or one of ``'edge'``,
            ``'reflect'``, ``'symmetric'`` (uses ``np.pad`` with that mode).

            - ``'constant'``, fills the border with a constant value
            - ``'edge'``, repeats the last value on the edge
            - ``'reflect'``, reflects values on the edge, omitting the edge value
            - ``'symmetric'``, reflects values on the edge, repeating the edge value

    Returns:
        PIL.Image.Image, padded image. Palette ('P') images keep their original palette.

    Raises:
        TypeError: If `img` is not a PIL image, `padding` is not a number/list/tuple, or
            `fill_value` has an unsupported type.
        ValueError: If `padding` has a length other than 2 or 4, or `padding_mode` is unknown.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    if isinstance(padding, numbers.Number):
        left = top = right = bottom = padding
    elif isinstance(padding, (tuple, list)):
        if len(padding) == 2:
            left, top = padding
            right, bottom = left, top
        elif len(padding) == 4:
            left, top, right, bottom = padding
        else:
            raise ValueError("The size of the padding list or tuple should be 2 or 4.")
    else:
        raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4.")

    if isinstance(fill_value, list):
        # JSON round-trips turn tuples into lists; PIL expects a tuple colour.
        fill_value = tuple(fill_value)
    if not isinstance(fill_value, (numbers.Number, str, tuple)):
        raise TypeError("fill_value can be any of: an integer, a string or a tuple.")
    if padding_mode not in ['constant', 'edge', 'reflect', 'symmetric']:
        raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")

    is_palette = img.mode == 'P'
    # Palette images lose their palette when rebuilt, so remember it and restore it afterwards.
    palette = img.getpalette() if is_palette else None

    if padding_mode == 'constant':
        if isinstance(fill_value, tuple) and img.mode in ('L', '1'):
            # single-band images take a one-element colour
            fill_value = (fill_value[0],)
        result = ImageOps.expand(img, border=(left, top, right, bottom), fill=fill_value)
    else:
        pixels = np.asarray(img)
        if is_palette or pixels.ndim == 2:
            pixels = np.pad(pixels, ((top, bottom), (left, right)), padding_mode)
        elif pixels.ndim == 3:
            # HWC: pad height and width only, never the channel axis
            pixels = np.pad(pixels, ((top, bottom), (left, right), (0, 0)), padding_mode)
        result = Image.fromarray(pixels)

    if is_palette:
        result.putpalette(palette)
    return result
989
+
990
+
991
def get_perspective_params(img, distortion_scale):
    """
    Sample random corner displacements for RandomPerspective.

    Each corner may move inward by at most ``int(dim / 2 * distortion_scale)`` pixels along each
    axis, where ``dim`` is the image width (x) or height (y).

    Args:
        img: Image exposing ``size`` as ``(width, height)`` (a PIL image in this module).
        distortion_scale (float): Fraction of the half-width/half-height a corner may move.

    Returns:
        tuple(list, list): ``(start_points, end_points)``. Both lists hold (x, y) points ordered
        top-left, top-right, bottom-right, bottom-left; `start_points` are the image corners and
        `end_points` the randomly displaced corners.
    """
    width, height = img.size
    max_dx = int(width / 2 * distortion_scale)
    max_dy = int(height / 2 * distortion_scale)
    right, bottom = width - 1, height - 1

    def near_low(limit):
        # Coordinate within `limit` pixels of 0.
        return random.randint(0, limit)

    def near_high(edge, limit):
        # Coordinate within `limit` pixels below `edge`.
        return random.randint(edge - limit, edge)

    # Evaluation order (x before y, corner by corner) fixes the RNG consumption order.
    end_points = [
        (near_low(max_dx), near_low(max_dy)),
        (near_high(right, max_dx), near_low(max_dy)),
        (near_high(right, max_dx), near_high(bottom, max_dy)),
        (near_low(max_dx), near_high(bottom, max_dy)),
    ]
    start_points = [(0, 0), (right, 0), (right, bottom), (0, bottom)]
    return start_points, end_points
1009
+
1010
+
1011
def perspective(img, start_points, end_points, interpolation=Inter.BICUBIC):
    """
    Warp a PIL image with the perspective transform that maps `start_points` onto `end_points`.

    Args:
        img (PIL.Image.Image): Image to transform.
        start_points (list): [top_left, top_right, bottom_right, bottom_left] (x, y) points of the original image.
        end_points (list): [top_left, top_right, bottom_right, bottom_left] (x, y) points of the transformed image.
        interpolation: Resampling mode passed to ``Image.transform``. Default: ``Inter.BICUBIC``.

    Returns:
        PIL.Image.Image, the perspective-transformed image (same size as the input).

    Raises:
        TypeError: If `img` is not a PIL image.
    """

    def _input_to_coeffs(original_points, transformed_points):
        # Solve for the 8 coefficients (a, b, c, d, e, f, g, h) that PIL's PERSPECTIVE transform
        # uses to map each output pixel back to an input pixel. Each point pair gives two rows.
        # Ref: "Using Projective Geometry to Correct a Camera", AMS feature column
        # http://www.ams.org/publicoutreach/feature-column/fc-2013-03
        # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Geometry.c#L377
        rows = []
        for (tx, ty), (ox, oy) in zip(transformed_points, original_points):
            rows.append([tx, ty, 1, 0, 0, 0, -ox * tx, -ox * ty])
            rows.append([0, 0, 0, tx, ty, 1, -oy * tx, -oy * ty])
        a_matrix = np.array(rows, dtype=float)
        b_vector = np.array(original_points, dtype=float).reshape(8)
        solution, _, _, _ = np.linalg.lstsq(a_matrix, b_vector, rcond=None)
        return solution.tolist()

    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    coeffs = _input_to_coeffs(start_points, end_points)
    return img.transform(img.size, PERSPECTIVE, coeffs, interpolation)
1045
+
1046
+
1047
def get_erase_params(np_img, scale, ratio, value, bounded, max_attempts):
    """
    Helper function to get parameters for RandomErasing/Cutout.

    Samples a rectangle whose area is a `scale` fraction of the image and whose aspect ratio
    (width / height) is drawn from `ratio`, retrying up to `max_attempts` times.

    Args:
        np_img (numpy.ndarray): Image of shape (C, H, W).
        scale (sequence): (min, max) fraction of the image area to erase.
        ratio (sequence): (min, max) aspect ratio (width / height) of the erased region.
        value (Union[int, float, str, bytes, sequence]): Fill value. A number is used as-is, any
            str/bytes (e.g. ``'random'``) requests standard-normal noise, and a 3-element
            list/tuple gives one value per channel.
        bounded (bool): If ``True``, the region lies fully inside the image. Otherwise it is
            centred on a random pixel and clipped to the image borders, so it may be smaller
            than the sampled size.
        max_attempts (int): Number of sampling attempts before giving up.

    Returns:
        tuple, ``(i, j, erase_h, erase_w, erase_value)``: top row, left column, region height and
        width, and the value(s) to write. If no region is found within `max_attempts`, returns
        ``(0, 0, H, W, np_img)`` so that erasing leaves the image unchanged.

    Raises:
        TypeError: If `np_img` is not a NumPy array.
        ValueError: If `value` has an unsupported form (checked once a region has been found).
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))

    def clip(x, lower, upper):
        return max(lower, min(x, upper))

    image_c, image_h, image_w = np_img.shape
    area = image_h * image_w

    for _ in range(max_attempts):
        erase_area = random.uniform(scale[0], scale[1]) * area
        aspect_ratio = random.uniform(ratio[0], ratio[1])
        erase_w = int(round(math.sqrt(erase_area * aspect_ratio)))
        erase_h = int(round(erase_w / aspect_ratio))

        if erase_h < image_h and erase_w < image_w:
            if bounded:
                i = random.randint(0, image_h - erase_h)
                j = random.randint(0, image_w - erase_w)
            else:
                x = random.randint(0, image_w)
                y = random.randint(0, image_h)
                x1 = clip(x - erase_w // 2, 0, image_w)
                x2 = clip(x + erase_w // 2, 0, image_w)
                y1 = clip(y - erase_h // 2, 0, image_h)
                y2 = clip(y + erase_h // 2, 0, image_h)

                i, j, erase_h, erase_w = y1, x1, y2 - y1, x2 - x1

            # Build the fill array from the FINAL region size; the unbounded branch may have
            # shrunk erase_h/erase_w, and erase() assigns into np_img[:, i:i+h, j:j+w].
            erase_shape = (image_c, erase_h, erase_w)

            if isinstance(value, numbers.Number):
                erase_value = value
            elif isinstance(value, (str, bytes)):
                erase_value = np.random.normal(loc=0.0, scale=1.0, size=erase_shape)
            elif isinstance(value, (tuple, list)) and len(value) == 3:
                value = np.array(value)
                erase_value = np.multiply(np.ones(erase_shape), value[:, None, None])
            else:
                raise ValueError("The value for erasing should be either a single value, or a string "
                                 "'random', or a sequence of 3 elements for RGB respectively.")

            return i, j, erase_h, erase_w, erase_value

    # exceeding max_attempts, return original image
    return 0, 0, image_h, image_w, np_img
1095
+
1096
+
1097
def erase(np_img, i, j, height, width, erase_value, inplace=False):
    """
    Overwrite a rectangular region of a CHW NumPy image with `erase_value`.

    Args:
        np_img (numpy.ndarray): Image of shape (C, H, W).
        i (int): Top row of the region (height axis).
        j (int): Left column of the region (width axis).
        height (int): Region height.
        width (int): Region width.
        erase_value: Scalar or array assigned to ``np_img[:, i:i + height, j:j + width]``,
            typically obtained from ``get_erase_params()``.
        inplace (bool, optional): Modify `np_img` directly instead of a copy. Default: ``False``.

    Returns:
        numpy.ndarray, the erased image (`np_img` itself when `inplace` is ``True``).

    Raises:
        TypeError: If `np_img` is not a NumPy array.
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}.'.format(type(np_img)))

    target = np_img if inplace else np_img.copy()
    # (i, j) index the height and width axes of CHW; every channel is overwritten.
    target[:, i:i + height, j:j + width] = erase_value
    return target
1121
+
1122
+
1123
def linear_transform(np_img, transformation_matrix, mean_vector):
    """
    Apply a square linear transform to a flattened, mean-centred NumPy image.

    The image is flattened to a row vector, `mean_vector` is subtracted, the result is
    multiplied by `transformation_matrix`, and the product is reshaped to the input shape.

    Args:
        np_img (numpy.ndarray): Image of shape (C, H, W).
        transformation_matrix (numpy.ndarray): Square matrix of shape (D, D), D = C x H x W.
        mean_vector (numpy.ndarray): Array of shape (D,).

    Returns:
        numpy.ndarray, the transformed image with the same shape as `np_img`.

    Raises:
        TypeError: If `np_img` is not a NumPy array.
        ValueError: If the matrix is not square or its size does not match the image or `mean_vector`.
    """
    if not is_numpy(np_img):
        raise TypeError('img should be NumPy array. Got {}'.format(type(np_img)))

    dim = transformation_matrix.shape[0]
    if dim != transformation_matrix.shape[1]:
        raise ValueError("transformation_matrix should be a square matrix. "
                         "Got shape {} instead".format(transformation_matrix.shape))
    if np.prod(np_img.shape) != dim:
        raise ValueError("transformation_matrix shape {0} not compatible with "
                         "Numpy image shape {1}.".format(transformation_matrix.shape, np_img.shape))
    if mean_vector.shape[0] != dim:
        raise ValueError("mean_vector length {0} should match either one dimension of the square "
                         "transformation_matrix {1}.".format(mean_vector.shape[0], transformation_matrix.shape))

    centered = np_img.reshape(1, -1) - mean_vector
    result = np.dot(centered, transformation_matrix)
    if result.size != np_img.size:
        raise ValueError("Linear transform failed, input shape should match with transformation_matrix.")
    return result.reshape(np_img.shape)
1155
+
1156
+
1157
def random_affine(img, angle, translations, scale, shear, resample, fill_value=0):
    """
    Applies a random Affine transformation on the input PIL Image.

    Random parameters are drawn from the given ranges. The inverse affine matrix (output -> input
    coordinates) is built around the image center and passed to ``Image.transform``.

    Args:
        img (PIL.Image.Image): Image to be applied affine transformation.
        angle (Sequence): (min, max) range of the rotation angle in degrees, clockwise.
        translations (Sequence, optional): Maximum horizontal and vertical translation, as fractions
            of the image width and height. If ``None``, no translation is applied.
        scale (Sequence, optional): (min, max) range of the scale factor. If ``None``, scale is 1.0.
        shear (Sequence, optional): Shear range in degrees, either (min, max) for the X axis only, or
            (x_min, x_max, y_min, y_max) for both axes. If ``None``, no shear is applied.
        resample (Inter): Resampling filter passed to ``Image.transform``.
        fill_value (Union[tuple, list, int], optional): Value used to fill the area outside the
            transform in the output image. Only passed to Pillow versions >= 5.0.0. If ``None``,
            Pillow's default fill is used.

    Returns:
        PIL.Image.Image, randomly affine transformed image.

    Raises:
        ValueError: If `img` is not a Pillow image, or `shear` has a length other than 2 or 4.
    """
    if not is_pil(img):
        raise ValueError("Input image should be a Pillow image.")

    # rotation
    angle = random.uniform(angle[0], angle[1])

    # translation (in pixels, rounded)
    if translations is not None:
        max_dx = translations[0] * img.size[0]
        max_dy = translations[1] * img.size[1]
        translations = (np.round(random.uniform(-max_dx, max_dx)),
                        np.round(random.uniform(-max_dy, max_dy)))
    else:
        translations = (0, 0)

    # scale
    if scale is not None:
        scale = random.uniform(scale[0], scale[1])
    else:
        scale = 1.0

    # shear: a 2-element range shears along X only; a 4-element range shears along X and Y.
    if shear is not None:
        if len(shear) == 2:
            shear = [random.uniform(shear[0], shear[1]), 0.]
        elif len(shear) == 4:
            shear = [random.uniform(shear[0], shear[1]),
                     random.uniform(shear[2], shear[3])]
    else:
        shear = 0.0

    output_size = img.size
    center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)

    angle = math.radians(angle)
    if isinstance(shear, (tuple, list)) and len(shear) == 2:
        shear = [math.radians(s) for s in shear]
    elif isinstance(shear, numbers.Number):
        shear = math.radians(shear)
        shear = [shear, 0]
    else:
        # Reached when a shear sequence of unsupported length was given.
        raise ValueError(
            "Shear should be a single value or a tuple/list containing " +
            "two values. Got {}.".format(shear))

    scale = 1.0 / scale

    # Inverted rotation matrix with scale and shear
    d = math.cos(angle + shear[0]) * math.cos(angle + shear[1]) + \
        math.sin(angle + shear[0]) * math.sin(angle + shear[1])
    matrix = [
        math.cos(angle + shear[0]), math.sin(angle + shear[0]), 0,
        -math.sin(angle + shear[1]), math.cos(angle + shear[1]), 0
    ]
    matrix = [scale / d * m for m in matrix]

    # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
    matrix[2] += matrix[0] * (-center[0] - translations[0]) + matrix[1] * (-center[1] - translations[1])
    matrix[5] += matrix[3] * (-center[0] - translations[0]) + matrix[4] * (-center[1] - translations[1])

    # Apply center translation: C * RSS^-1 * C^-1 * T^-1
    matrix[2] += center[0]
    matrix[5] += center[1]

    # Ensure fill_value of type list (from serialize JSON support) is converted to type tuple
    kwarg_fill_value = tuple(fill_value) if isinstance(fill_value, list) else fill_value

    # ``fillcolor`` is only passed for Pillow >= 5. Compare the numeric major version: the previous
    # string comparison ('10.4.0' >= '5' is False) silently dropped fill_value on Pillow 10 and later.
    pil_major_version = int(str(Image.__version__).split(".")[0])
    if pil_major_version >= 5:
        kwargs = {"fillcolor": kwarg_fill_value}
    else:
        kwargs = {}
    return img.transform(output_size, AFFINE, matrix, resample, **kwargs)
1248
+
1249
+
1250
def mix_up_single(batch_size, img, label, alpha=0.2):
    """
    Mix up images and labels inside a single batch. Labels must already be one-hot encoded.

    Sample ``i`` is blended with sample ``(i + 1) % batch_size``. The per-sample weight is drawn
    from ``Beta(alpha, alpha)``.

    Args:
        batch_size (int): The batch size of dataset.
        img (numpy.ndarray): NumPy image batch to be mixed; the weights are reshaped to
            (batch_size, 1, 1, 1) for broadcasting against it.
        label (numpy.ndarray): NumPy label batch to be mixed; the weights are reshaped to
            (batch_size, 1) for broadcasting against it.
        alpha (float): Parameter of the Beta distribution the mix weights are drawn from.

    Returns:
        mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
        mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
    """
    weights = np.random.beta(alpha, alpha, batch_size)

    # Row order [1, 2, ..., batch_size - 1, 0]: every sample is paired with its successor.
    successor = list(range(1, batch_size)) + [0]

    img_w = weights.reshape((batch_size, 1, 1, 1))
    blended_img = img_w * img + (1 - img_w) * img[successor, ...]

    label_w = weights.reshape((batch_size, 1))
    blended_label = label_w * label + (1 - label_w) * label[successor, ...]

    return blended_img, blended_label
1278
+
1279
+
1280
def mix_up_muti(tmp, batch_size, img, label, alpha=0.2):
    """
    Mix up images and labels across consecutive batches. Labels must already be one-hot encoded.

    The current batch is blended with the previously returned (already mixed) batch stored on
    `tmp`. On the first call the weights are forced to 1, so the batch is returned unmixed.

    Args:
        tmp (class object): State holder; reads and writes `is_first`, `image` and `label`.
        batch_size (int): the batch size of dataset.
        img (numpy.ndarray): NumPy image to be applied mix up transformation.
        label (numpy.ndarray): NumPy label to be applied mix up transformation.
        alpha (float): Parameter of the Beta distribution the mix weights are drawn from.

    Returns:
        mix_img (numpy.ndarray): NumPy image after being applied mix up transformation.
        mix_label (numpy.ndarray): NumPy label after being applied mix up transformation.
    """
    # Note: the Beta sample is drawn even on the first call, where it is then discarded.
    weights = np.random.beta(alpha, alpha, batch_size)
    if tmp.is_first:
        weights = np.ones(batch_size)
        tmp.is_first = False

    img_w = weights.reshape((batch_size, 1, 1, 1))
    blended_img = img_w * img + (1 - img_w) * tmp.image

    label_w = weights.reshape(batch_size, 1)
    blended_label = label_w * label + (1 - label_w) * tmp.label

    # Remember this output so the next batch is mixed with it.
    tmp.image, tmp.label = blended_img, blended_label
    return blended_img, blended_label
1309
+
1310
+
1311
def rgb_to_bgr(np_rgb_img, is_hwc):
    """
    Convert RGB img to BGR img by reversing the channel axis.

    Args:
        np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
        is_hwc (Bool): If ``True``, the shape of `np_rgb_img` is (H, W, C), otherwise must be (C, H, W).

    Returns:
        np_bgr_img (numpy.ndarray), NumPy BGR image with same type of `np_rgb_img`
        (a reversed-stride view of the input, not a copy).
    """
    channel_axis = 2 if is_hwc else 0
    # Full slices everywhere except the channel axis, which is walked backwards.
    selector = [slice(None)] * 3
    selector[channel_axis] = slice(None, None, -1)
    return np_rgb_img[tuple(selector)]
1327
+
1328
+
1329
def rgb_to_bgrs(np_rgb_imgs, is_hwc):
    """
    Convert a single RGB image or a batch of RGB images to BGR.

    Args:
        np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
            or (C, H, W) or (N, C, H, W) to be converted.
        is_hwc (Bool): If True, the shape of np_rgb_imgs is (H, W, C) or (N, H, W, C);
            If False, the shape of np_rgb_imgs is (C, H, W) or (N, C, H, W).

    Returns:
        np_bgr_imgs (numpy.ndarray), NumPy BGR images with same type of np_rgb_imgs.

    Raises:
        TypeError: If the input is not a NumPy array, `is_hwc` is not a bool, the rank is not
            3 or 4, or the channel axis does not have size 3.
    """
    if not is_numpy(np_rgb_imgs):
        raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
    if not isinstance(is_hwc, bool):
        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))

    shape = np_rgb_imgs.shape
    rank = len(shape)
    if rank not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
                        "Got {}.".format(shape))

    # Channels sit on the last axis for HWC layouts, and right after the optional batch axis for CHW.
    channel_axis = rank - 1 if is_hwc else rank - 3
    channels = shape[channel_axis]
    if channels != 3:
        raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(channels))

    # A single image, or an empty batch, goes straight to the per-image converter.
    if rank == 3 or shape[0] == 0:
        return rgb_to_bgr(np_rgb_imgs, is_hwc)
    converted = [rgb_to_bgr(single, is_hwc) for single in np_rgb_imgs]
    return np.array(converted)
1372
+
1373
+
1374
def rgb_to_hsv(np_rgb_img, is_hwc):
    """
    Convert RGB img to HSV img, pixel by pixel, using ``colorsys.rgb_to_hsv``.

    Args:
        np_rgb_img (numpy.ndarray): NumPy RGB image array of shape (H, W, C) or (C, H, W) to be converted.
        is_hwc (Bool): If ``True``, the shape of `np_rgb_img` is (H, W, C), otherwise must be (C, H, W).

    Returns:
        np_hsv_img (numpy.ndarray), NumPy HSV image with the same layout as `np_rgb_img`.
    """
    channel_axis = 2 if is_hwc else 0

    def plane(index):
        # Select one channel plane along ``channel_axis``.
        selector = [slice(None)] * 3
        selector[channel_axis] = index
        return np_rgb_img[tuple(selector)]

    red, green, blue = plane(0), plane(1), plane(2)
    hue, saturation, value = np.vectorize(colorsys.rgb_to_hsv)(red, green, blue)
    return np.stack((hue, saturation, value), axis=channel_axis)
1397
+
1398
+
1399
def rgb_to_hsvs(np_rgb_imgs, is_hwc):
    """
    Convert a single RGB image or a batch of RGB images to HSV.

    Args:
        np_rgb_imgs (numpy.ndarray): NumPy RGB images array of shape (H, W, C) or (N, H, W, C),
            or (C, H, W) or (N, C, H, W) to be converted.
        is_hwc (Bool): If ``True``, the shape of `np_rgb_imgs` is (H, W, C) or (N, H, W, C);
            If ``False``, the shape of `np_rgb_imgs` is (C, H, W) or (N, C, H, W).

    Returns:
        np_hsv_imgs (numpy.ndarray), NumPy HSV images with same type of np_rgb_imgs.

    Raises:
        TypeError: If the input is not a NumPy array, `is_hwc` is not a bool, the rank is not
            3 or 4, or the channel axis does not have size 3.
    """
    if not is_numpy(np_rgb_imgs):
        raise TypeError("img should be NumPy image. Got {}".format(type(np_rgb_imgs)))
    if not isinstance(is_hwc, bool):
        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))

    shape = np_rgb_imgs.shape
    rank = len(shape)
    if rank not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C ,H, W)/(N, C, H, W). "
                        "Got {}.".format(shape))

    # Channels sit on the last axis for HWC layouts, and right after the optional batch axis for CHW.
    channel_axis = rank - 1 if is_hwc else rank - 3
    channels = shape[channel_axis]
    if channels != 3:
        raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(channels))

    # A single image, or an empty batch, goes straight to the per-image converter.
    if rank == 3 or shape[0] == 0:
        return rgb_to_hsv(np_rgb_imgs, is_hwc)
    converted = [rgb_to_hsv(single, is_hwc) for single in np_rgb_imgs]
    return np.array(converted)
1442
+
1443
+
1444
def hsv_to_rgb(np_hsv_img, is_hwc):
    """
    Convert HSV img to RGB img, pixel by pixel, using ``colorsys.hsv_to_rgb``.

    Args:
        np_hsv_img (numpy.ndarray): NumPy HSV image array of shape (H, W, C) or (C, H, W) to be converted.
        is_hwc (Bool): If ``True``, the shape of `np_hsv_img` is (H, W, C), otherwise must be (C, H, W).

    Returns:
        np_rgb_img (numpy.ndarray), NumPy RGB image with the same layout as `np_hsv_img`.
    """
    channel_axis = 2 if is_hwc else 0

    def plane(index):
        # Select one channel plane along ``channel_axis``.
        selector = [slice(None)] * 3
        selector[channel_axis] = index
        return np_hsv_img[tuple(selector)]

    hue, saturation, value = plane(0), plane(1), plane(2)
    red, green, blue = np.vectorize(colorsys.hsv_to_rgb)(hue, saturation, value)
    return np.stack((red, green, blue), axis=channel_axis)
1468
+
1469
+
1470
def hsv_to_rgbs(np_hsv_imgs, is_hwc):
    """
    Convert a single HSV image or a batch of HSV images to RGB.

    Args:
        np_hsv_imgs (numpy.ndarray): NumPy HSV images array of shape (H, W, C) or (N, H, W, C),
            or (C, H, W) or (N, C, H, W) to be converted.
        is_hwc (Bool): If ``True``, the shape of `np_hsv_imgs` is (H, W, C) or (N, H, W, C);
            If ``False``, the shape of `np_hsv_imgs` is (C, H, W) or (N, C, H, W).

    Returns:
        np_rgb_imgs (numpy.ndarray), NumPy RGB images with same type of `np_hsv_imgs`.

    Raises:
        TypeError: If the input is not a NumPy array, `is_hwc` is not a bool, the rank is not
            3 or 4, or the channel axis does not have size 3.
    """
    if not is_numpy(np_hsv_imgs):
        raise TypeError("img should be NumPy image. Got {}.".format(type(np_hsv_imgs)))
    if not isinstance(is_hwc, bool):
        raise TypeError("is_hwc should be bool type. Got {}.".format(type(is_hwc)))

    shape = np_hsv_imgs.shape
    rank = len(shape)
    if rank not in (3, 4):
        raise TypeError("img shape should be (H, W, C)/(N, H, W, C)/(C, H, W)/(N, C, H, W). "
                        "Got {}.".format(shape))

    # Channels sit on the last axis for HWC layouts, and right after the optional batch axis for CHW.
    channel_axis = rank - 1 if is_hwc else rank - 3
    channels = shape[channel_axis]
    if channels != 3:
        raise TypeError("img should be 3 channels RGB img. Got {} channels.".format(channels))

    # A single image, or an empty batch, goes straight to the per-image converter.
    if rank == 3 or shape[0] == 0:
        return hsv_to_rgb(np_hsv_imgs, is_hwc)
    converted = [hsv_to_rgb(single, is_hwc) for single in np_hsv_imgs]
    return np.array(converted)
1513
+
1514
+
1515
def random_color(img, degrees):
    """
    Adjust the color of the input PIL Image by a random degree.

    The enhancement factor is drawn uniformly from [degrees[0], degrees[1]].

    Args:
        img (PIL.Image.Image): Image to be color adjusted.
        degrees (sequence): Range of random color adjustment degrees, in (min, max) format.

    Returns:
        PIL.Image.Image, color adjusted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    factor = random.uniform(degrees[0], degrees[1])
    enhancer = ImageEnhance.Color(img)
    return enhancer.enhance(factor)
1533
+
1534
+
1535
def random_sharpness(img, degrees):
    """
    Adjust the sharpness of the input PIL Image by a random degree.

    The enhancement factor is drawn uniformly from [degrees[0], degrees[1]].

    Args:
        img (PIL.Image.Image): Image to be sharpness adjusted.
        degrees (sequence): Range of random sharpness adjustment degrees, in (min, max) format.

    Returns:
        PIL.Image.Image, sharpness adjusted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError(augment_error_message.format(type(img)))

    factor = random.uniform(degrees[0], degrees[1])
    enhancer = ImageEnhance.Sharpness(img)
    return enhancer.enhance(factor)
1553
+
1554
+
1555
def adjust_gamma(img, gamma, gain):
    """
    Adjust gamma of the input PIL Image.

    Each 8-bit value ``x`` is mapped through the lookup table
    ``int((256 - 1e-3) * gain * (x / 255) ** gamma)``. Only images with 1 or 3 bands are
    adjusted; images with any other band count are returned unchanged.

    Args:
        img (PIL.Image.Image): Image to be augmented with AdjustGamma.
        gamma (float): Non negative real number, same as gamma in the equation.
        gain (float): The constant multiplier.

    Returns:
        PIL.Image.Image, augmented image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if not is_pil(img):
        raise TypeError("img should be PIL image. Got {}.".format(type(img)))

    gamma_table = [int((255 + 1 - 1e-3) * gain * pow(x / 255., gamma)) for x in range(256)]

    # getbands() reports the band count without copying pixel data. split(), used previously
    # (twice), materializes a full image per band just to count them.
    num_bands = len(img.getbands())
    if num_bands == 3:
        # point() expects one 256-entry table per band, concatenated.
        return img.point(gamma_table * 3)
    if num_bands == 1:
        return img.point(gamma_table)
    return img
1579
+
1580
+
1581
def auto_contrast(img, cutoff, ignore):
    """
    Automatically maximize the contrast of the input PIL Image via ``ImageOps.autocontrast``.

    Args:
        img (PIL.Image): Image to be augmented with AutoContrast.
        cutoff (float): Percent of pixels to cut off from the histogram.
        ignore (Union[int, Sequence[int]]): Pixel values to ignore.

    Returns:
        PIL.Image, augmented image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageOps.autocontrast(img, cutoff, ignore)
    raise TypeError(augment_error_message.format(type(img)))
1598
+
1599
+
1600
def invert_color(img):
    """
    Invert colors of input PIL Image via ``ImageOps.invert``.

    Args:
        img (PIL.Image.Image): Image to be color inverted.

    Returns:
        PIL.Image.Image, color inverted image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageOps.invert(img)
    raise TypeError(augment_error_message.format(type(img)))
1616
+
1617
+
1618
def equalize(img):
    """
    Equalize the histogram of input PIL Image via ``ImageOps.equalize``.

    Args:
        img (PIL.Image.Image): Image to be equalized.

    Returns:
        PIL.Image.Image, equalized image.

    Raises:
        TypeError: If `img` is not a PIL image.
    """
    if is_pil(img):
        return ImageOps.equalize(img)
    raise TypeError(augment_error_message.format(type(img)))
1634
+
1635
+
1636
def uniform_augment(img, transforms, num_ops):
    """
    Uniformly select and apply a number of transforms sequentially from a list of transforms.

    `num_ops` distinct transforms are chosen without replacement. For each one, a probability is
    drawn at random and the transform is applied with that probability. Applied transforms receive
    a copy of the current image. All transforms must share the same input/output data type.

    Args:
        img: Image to be applied transformation; must provide a ``copy()`` method.
        transforms (list): List of transformations to be chosen from to apply.
        num_ops (int): Number of transforms to sequentially apply.

    Returns:
        img, Transformed image.
    """
    chosen = np.random.choice(len(transforms), size=num_ops, replace=False)
    for transform in (transforms[i] for i in chosen):
        threshold = random.random()
        draw = random.random()
        if draw < threshold:
            img = transform(img.copy())
    return img